Example #1
class TestCardinal:
    inverse_normalizer_en = InverseNormalizer(lang='en') if PYNINI_AVAILABLE else None

    @parameterized.expand(parse_test_case_file('en/data_inverse_text_normalization/test_cases_cardinal.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE, reason="`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_denorm(self, test_input, expected):
        pred = self.inverse_normalizer_en.inverse_normalize(test_input, verbose=False)
        assert pred == expected

    normalizer_en = Normalizer(input_case='cased', lang='en') if PYNINI_AVAILABLE else None
    normalizer_with_audio_en = NormalizerWithAudio(input_case='cased', lang='en') if PYNINI_AVAILABLE else None

    @parameterized.expand(parse_test_case_file('en/data_text_normalization/test_cases_cardinal.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE, reason="`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm(self, test_input, expected):
        pred = self.normalizer_en.normalize(test_input, verbose=False)
        assert pred == expected
        pred_non_deterministic = self.normalizer_with_audio_en.normalize(test_input, n_tagged=100)
        assert expected in pred_non_deterministic
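
Note: `parse_test_case_file` itself is not shown in these examples. A minimal sketch, assuming each line of the test-case file is a tilde-delimited `written~spoken` pair (the actual NeMo helper may differ):

from typing import List, Tuple

def parse_test_case_file(file_name: str) -> List[Tuple[str, str]]:
    """Hypothetical loader: one `written~spoken` test case per line."""
    test_cases = []
    with open(file_name, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            written, spoken = line.split('~', 1)
            test_cases.append((written, spoken))
    return test_cases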
Example #2
class TestBoundary:

    normalizer_en = (Normalizer(input_case='cased',
                                lang='en',
                                cache_dir=CACHE_DIR,
                                overwrite_cache=False)
                     if PYNINI_AVAILABLE else None)
    normalizer_with_audio_en = (NormalizerWithAudio(input_case='cased',
                                                    lang='en',
                                                    cache_dir=CACHE_DIR,
                                                    overwrite_cache=False)
                                if PYNINI_AVAILABLE and CACHE_DIR else None)

    @parameterized.expand(
        parse_test_case_file(
            'en/data_text_normalization/test_cases_boundary.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm(self, test_input, expected):
        pred = self.normalizer_en.normalize(test_input, verbose=False)
        assert pred == expected

        if self.normalizer_with_audio_en:
            pred_non_deterministic = self.normalizer_with_audio_en.normalize(
                test_input, n_tagged=30, punct_post_process=False)
            assert expected in pred_non_deterministic
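
A usage note on the cached construction above: `cache_dir` lets the compiled grammar files be reused across runs instead of being rebuilt each time. A hedged standalone sketch (assuming `pynini` is installed and the import path matches this NeMo version):

from nemo_text_processing.text_normalization.normalize import Normalizer

# '/tmp/tn_cache' stands in for CACHE_DIR: any writable directory for the compiled .far files.
normalizer = Normalizer(input_case='cased', lang='en',
                        cache_dir='/tmp/tn_cache', overwrite_cache=False)
print(normalizer.normalize('123 km', verbose=False))  # e.g. 'one hundred twenty three kilometers'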
Example #3
class TestBoundary:

    normalizer = Normalizer(input_case='cased') if PYNINI_AVAILABLE else None
    normalizer_with_audio = NormalizerWithAudio(
        input_case='cased') if PYNINI_AVAILABLE else None

    @parameterized.expand(
        parse_test_case_file('data_text_normalization/test_cases_boundary.txt')
    )
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm(self, test_input, expected):
        pred = self.normalizer.normalize(test_input, verbose=False)
        assert pred == expected
        pred_non_deterministic = self.normalizer_with_audio.normalize(
            test_input,
            n_tagged=100,
            punct_pre_process=False,
            punct_post_process=False)
        assert expected in pred_non_deterministic
Example #4
class TestDate:
    inverse_normalizer = InverseNormalizer() if PYNINI_AVAILABLE else None

    @parameterized.expand(
        parse_test_case_file(
            'data_inverse_text_normalization/test_cases_date.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_denorm(self, test_input, expected):
        pred = self.inverse_normalizer.inverse_normalize(test_input,
                                                         verbose=False)
        assert pred == expected

    normalizer = Normalizer(input_case='cased') if PYNINI_AVAILABLE else None
    normalizer_with_audio = NormalizerWithAudio(
        input_case='cased') if PYNINI_AVAILABLE else None

    @parameterized.expand(
        parse_test_case_file('data_text_normalization/test_cases_date.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm_uncased(self, test_input, expected):
        pred = self.normalizer.normalize(test_input, verbose=False)
        assert pred == expected
        pred_non_deterministic = self.normalizer_with_audio.normalize(
            test_input, n_tagged=100)
        assert expected in pred_non_deterministic

    normalizer_uppercased = Normalizer(
        input_case='cased') if PYNINI_AVAILABLE else None
    cases_uppercased = {
        "Aug. 8": "august eighth",
        "8 Aug.": "the eighth of august",
        "aug. 8": "august eighth"
    }

    @parameterized.expand(cases_uppercased.items())
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm_cased(self, test_input, expected):
        pred = self.normalizer_uppercased.normalize(test_input, verbose=False)
        assert pred == expected
        pred_non_deterministic = self.normalizer_with_audio.normalize(
            test_input, n_tagged=100)
        assert expected in pred_non_deterministic
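
The two code paths exercised in `TestDate` can be contrasted directly; a hedged sketch (the expected outputs come from the `cases_uppercased` table above, not from running this snippet):

from nemo_text_processing.text_normalization.normalize import Normalizer
from nemo_text_processing.text_normalization.normalize_with_audio import NormalizerWithAudio

# Deterministic path: a single best verbalization.
norm = Normalizer(input_case='cased')
print(norm.normalize('Aug. 8', verbose=False))  # 'august eighth' per the table above

# Non-deterministic path: a collection of candidates; the tests only check membership.
norm_audio = NormalizerWithAudio(input_case='cased')
options = norm_audio.normalize('Aug. 8', n_tagged=100)
print('august eighth' in options)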
Example #5
class TestSpecialText:

    normalizer_en = (Normalizer(input_case='cased',
                                lang='en',
                                cache_dir=CACHE_DIR,
                                overwrite_cache=False)
                     if PYNINI_AVAILABLE else None)

    normalizer_with_audio_en = (NormalizerWithAudio(input_case='cased',
                                                    lang='en',
                                                    cache_dir=CACHE_DIR,
                                                    overwrite_cache=False)
                                if PYNINI_AVAILABLE and RUN_AUDIO_BASED_TESTS
                                else None)

    @parameterized.expand(
        parse_test_case_file(
            'en/data_text_normalization/test_cases_special_text.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm(self, test_input, expected):
        pred = self.normalizer_en.normalize(test_input, verbose=False)
        assert pred == expected
Example #6
class TestRoman:
    normalizer_en = (Normalizer(input_case='cased',
                                lang='en',
                                cache_dir=CACHE_DIR,
                                overwrite_cache=False)
                     if PYNINI_AVAILABLE else None)
    normalizer_with_audio_en = (NormalizerWithAudio(input_case='cased',
                                                    lang='en',
                                                    cache_dir=CACHE_DIR,
                                                    overwrite_cache=False)
                                if PYNINI_AVAILABLE and RUN_AUDIO_BASED_TESTS
                                else None)

    # address is tagged by the measure class
    @parameterized.expand(
        parse_test_case_file('en/data_text_normalization/test_cases_roman.txt')
    )
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm(self, test_input, expected):
        pred = self.normalizer_en.normalize(test_input, verbose=False)
        assert pred == expected

        if self.normalizer_with_audio_en:
            pred_non_deterministic = self.normalizer_with_audio_en.normalize(
                test_input, n_tagged=30, punct_post_process=False)
            assert expected in pred_non_deterministic
Example #7
class TestNormalizeWithAudio:

    normalizer_de = (NormalizerWithAudio(input_case='cased',
                                         lang='de',
                                         cache_dir=CACHE_DIR,
                                         overwrite_cache=False)
                     if PYNINI_AVAILABLE else None)

    @parameterized.expand(
        get_test_cases_multiple(
            'de/data_text_normalization/test_cases_normalize_with_audio.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm(self, test_input, expected):
        pred = self.normalizer_de.normalize(test_input,
                                            n_tagged=1000,
                                            punct_post_process=False)
        assert len(set(pred).intersection(set(expected))) == len(
            expected), f'missing: {set(expected).difference(set(pred))}'
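
`get_test_cases_multiple` is also not shown; a sketch under the assumption that each line carries one input followed by several acceptable verbalizations, all tilde-delimited (the real helper may differ):

from typing import List, Tuple

def get_test_cases_multiple(file_name: str) -> List[Tuple[str, List[str]]]:
    """Hypothetical loader: `written~spoken1~spoken2~...` per line."""
    test_cases = []
    with open(file_name, 'r', encoding='utf-8') as f:
        for line in f:
            parts = line.strip().split('~')
            if len(parts) >= 2:
                test_cases.append((parts[0], parts[1:]))
    return test_cases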
Example #8
    def setup_cgs(self, cfg: DictConfig):
        """
        Setup covering grammars (if enabled).
        :param cfg: Configs of the decoder model.
        """
        self.use_cg = True
        self.neural_confidence_threshold = cfg.get('neural_confidence_threshold', 0.99)
        self.n_tagged = cfg.get('n_tagged', 1)
        input_case = 'cased'  # input_case is cased by default
        if hasattr(self._tokenizer, 'do_lower_case') and self._tokenizer.do_lower_case:
            input_case = 'lower_cased'
        if not PYNINI_AVAILABLE:
            raise Exception(
                "`pynini` is not installed!\n"
                "Please run the `nemo_text_processing/setup.sh` script "
                "prior to usage of this toolkit."
            )
        self.cg_normalizer = NormalizerWithAudio(input_case=input_case, lang=self.lang)
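
The config keys read by `setup_cgs` can be exercised with a small OmegaConf snippet; the values below mirror the method's own fallback defaults, everything else is illustrative:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'neural_confidence_threshold': 0.99,  # below this sequence probability, fall back to CGs
    'n_tagged': 1,                        # number of tagged options requested from the CG normalizer
})
assert cfg.get('n_tagged', 1) == 1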
Example #9
class TestNormalizeWithAudio:

    normalizer = NormalizerWithAudio(input_case='cased') if PYNINI_AVAILABLE else None

    @parameterized.expand(get_test_cases_multiple('data_text_normalization/test_cases_normalize_with_audio.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE, reason="`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm(self, test_input, expected):
        pred = self.normalizer.normalize(test_input, n_tagged=700)
        assert len(set(pred).intersection(set(expected))) == len(expected), f'pred: {pred}'
Example #10
class TestWord:
    inverse_normalizer_en = (InverseNormalizer(
        lang='en', cache_dir=CACHE_DIR, overwrite_cache=False)
                             if PYNINI_AVAILABLE else None)

    @parameterized.expand(
        parse_test_case_file(
            'en/data_inverse_text_normalization/test_cases_word.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_denorm(self, test_input, expected):
        pred = self.inverse_normalizer_en.inverse_normalize(test_input,
                                                            verbose=False)
        assert pred == expected

    normalizer_en = (Normalizer(input_case='cased',
                                lang='en',
                                cache_dir=CACHE_DIR,
                                overwrite_cache=False,
                                post_process=True)
                     if PYNINI_AVAILABLE else None)
    normalizer_with_audio_en = (NormalizerWithAudio(input_case='cased',
                                                    lang='en',
                                                    cache_dir=CACHE_DIR,
                                                    overwrite_cache=False)
                                if PYNINI_AVAILABLE and RUN_AUDIO_BASED_TESTS
                                else None)

    @parameterized.expand(
        parse_test_case_file('en/data_text_normalization/test_cases_word.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm(self, test_input, expected):
        pred = self.normalizer_en.normalize(test_input, verbose=False)
        assert pred == expected, f"input: {test_input} != {expected}"

        if self.normalizer_with_audio_en:
            pred_non_deterministic = self.normalizer_with_audio_en.normalize(
                test_input, n_tagged=3, punct_post_process=False)
            assert expected in pred_non_deterministic, f"input: {test_input}"
Example #11
class TestWhitelist:
    inverse_normalizer_en = InverseNormalizer(lang='en') if PYNINI_AVAILABLE else None

    @parameterized.expand(parse_test_case_file('en/data_inverse_text_normalization/test_cases_whitelist.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE, reason="`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_denorm(self, test_input, expected):
        pred = self.inverse_normalizer_en.inverse_normalize(test_input, verbose=False)
        assert pred == expected

    normalizer_en = Normalizer(input_case='lower_cased') if PYNINI_AVAILABLE else None
    normalizer_with_audio_en = NormalizerWithAudio(input_case='cased', lang='en') if PYNINI_AVAILABLE else None

    @parameterized.expand(parse_test_case_file('en/data_text_normalization/test_cases_whitelist.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE, reason="`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm(self, test_input, expected):
        pred = self.normalizer_en.normalize(test_input, verbose=False)
        assert pred == expected
        pred_non_deterministic = self.normalizer_with_audio_en.normalize(test_input, n_tagged=100)
        assert expected in pred_non_deterministic

    normalizer_uppercased = Normalizer(input_case='cased', lang='en') if PYNINI_AVAILABLE else None
    cases_uppercased = {"Dr. Evil": "doctor Evil", "No. 4": "number four", "dr. Evil": "dr. Evil", "no. 4": "no. four"}

    @parameterized.expand(cases_uppercased.items())
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE, reason="`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm_cased(self, test_input, expected):
        pred = self.normalizer_uppercased.normalize(test_input, verbose=False)
        assert pred == expected
        pred_non_deterministic = self.normalizer_with_audio_en.normalize(test_input, n_tagged=100)
        assert expected in pred_non_deterministic
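
A hedged illustration of the `input_case` distinction driving `cases_uppercased` above (the expected outputs come from that table, not from running this snippet):

from nemo_text_processing.text_normalization.normalize import Normalizer

norm_cased = Normalizer(input_case='cased', lang='en')
print(norm_cased.normalize('Dr. Evil', verbose=False))  # 'doctor Evil' per the table
print(norm_cased.normalize('dr. Evil', verbose=False))  # 'dr. Evil': the lowercase form passes through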
Example #12
class TestRuNormalizeWithAudio:

    normalizer = NormalizerWithAudio(input_case='cased', lang='ru', cache_dir=CACHE_DIR) if PYNINI_AVAILABLE else None
Example #13
class TestRuNormalizeWithAudio:

    normalizer = NormalizerWithAudio(
        input_case='cased', lang='ru',
        cache_dir=CACHE_DIR) if PYNINI_AVAILABLE else None
    N_TAGGED = 3000

    @parameterized.expand(
        parse_test_case_file(
            'ru/data_text_normalization/test_cases_cardinal.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm_cardinal(self, expected, test_input):
        preds = self.normalizer.normalize(test_input, n_tagged=self.N_TAGGED)
        assert expected in preds

    @parameterized.expand(
        parse_test_case_file(
            'ru/data_text_normalization/test_cases_ordinal.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm_ordinal(self, expected, test_input):
        preds = self.normalizer.normalize(test_input, n_tagged=self.N_TAGGED)
        assert expected in preds

    @parameterized.expand(
        parse_test_case_file(
            'ru/data_text_normalization/test_cases_decimal.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm_decimal(self, expected, test_input):
        preds = self.normalizer.normalize(test_input, n_tagged=5000)
        assert expected in preds

    @parameterized.expand(
        parse_test_case_file(
            'ru/data_text_normalization/test_cases_measure.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm_measure(self, expected, test_input):
        preds = self.normalizer.normalize(test_input, n_tagged=self.N_TAGGED)
        assert expected in preds

    @parameterized.expand(
        parse_test_case_file('ru/data_text_normalization/test_cases_date.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm_date(self, expected, test_input):
        preds = self.normalizer.normalize(test_input, n_tagged=-1)
        assert expected in preds, f"'{expected}' not in preds"

    @parameterized.expand(
        parse_test_case_file(
            'ru/data_text_normalization/test_cases_telephone.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm_telephone(self, expected, test_input):
        preds = self.normalizer.normalize(test_input, n_tagged=-1)
        assert expected in preds, f"'{expected}' not in preds"

    @parameterized.expand(
        parse_test_case_file('ru/data_text_normalization/test_cases_money.txt')
    )
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm_money(self, expected, test_input):
        preds = self.normalizer.normalize(test_input, n_tagged=-1)
        assert expected in preds, f"'{expected}' not in preds"

    @parameterized.expand(
        parse_test_case_file('ru/data_text_normalization/test_cases_time.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm_time(self, expected, test_input):
        preds = self.normalizer.normalize(test_input, n_tagged=-1)
        assert expected in preds, f"'{expected}' not in preds"

    @parameterized.expand(
        parse_test_case_file(
            'ru/data_text_normalization/test_cases_electronic.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm_electronic(self, expected, test_input):
        preds = self.normalizer.normalize(test_input, n_tagged=self.N_TAGGED)
        assert expected in preds

    @parameterized.expand(
        parse_test_case_file(
            'ru/data_text_normalization/test_cases_whitelist.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm_whitelist(self, expected, test_input):
        preds = self.normalizer.normalize(test_input, n_tagged=self.N_TAGGED)
        assert expected in preds

    @parameterized.expand(
        parse_test_case_file('ru/data_text_normalization/test_cases_word.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm_word(self, expected, test_input):
        preds = self.normalizer.normalize(test_input, n_tagged=self.N_TAGGED)
        assert expected in preds
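
In the Russian tests above, `n_tagged=-1` is used for the more ambiguous classes (date, telephone, money, time); it is assumed here to return every tagged option rather than a fixed-size list, trading speed for recall. A hedged sketch:

from nemo_text_processing.text_normalization.normalize_with_audio import NormalizerWithAudio

normalizer_ru = NormalizerWithAudio(input_case='cased', lang='ru')
# n_tagged=-1 is assumed to disable pruning, so the expected verbalization
# cannot be dropped from the candidate set.
all_options = normalizer_ru.normalize('123', n_tagged=-1)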
Example #14
class DuplexDecoderModel(NLPModel):
    """
    Transformer-based (duplex) decoder model for TN/ITN.
    """
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable
        # global_rank and local_rank are set by LightningModule in Lightning 1.2.0
        self.world_size = 1
        if trainer is not None:
            self.world_size = trainer.num_nodes * trainer.num_gpus

        self._tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer)

        super().__init__(cfg=cfg, trainer=trainer)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(cfg.transformer)
        self.max_sequence_len = cfg.get('max_sequence_len',
                                        self._tokenizer.model_max_length)
        self.mode = cfg.get('mode', 'joint')

        self.transformer_name = cfg.transformer

        # Language
        self.lang = cfg.get('lang', None)

        # Covering Grammars
        self.cg_normalizer = None  # Default
        # We only support integrating with English TN covering grammars at the moment
        self.use_cg = cfg.get('use_cg',
                              False) and self.lang == constants.ENGLISH
        if self.use_cg:
            self.setup_cgs(cfg)

        # setup processor for detokenization
        self.processor = MosesProcessor(lang_id=self.lang)

    # Setup covering grammars (if enabled)
    def setup_cgs(self, cfg: DictConfig):
        """
        Setup covering grammars (if enabled).
        :param cfg: Configs of the decoder model.
        """
        self.use_cg = True
        self.neural_confidence_threshold = cfg.get(
            'neural_confidence_threshold', 0.99)
        self.n_tagged = cfg.get('n_tagged', 1)
        input_case = 'cased'  # input_case is cased by default
        if hasattr(self._tokenizer,
                   'do_lower_case') and self._tokenizer.do_lower_case:
            input_case = 'lower_cased'
        if not PYNINI_AVAILABLE:
            raise Exception(
                "`pynini` is not installed ! \n"
                "Please run the `nemo_text_processing/setup.sh` script"
                "prior to usage of this toolkit.")
        self.cg_normalizer = NormalizerWithAudio(input_case=input_case,
                                                 lang=self.lang)

    # Training
    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        # tarred dataset contains batches, and the first dimension of size 1 added by the DataLoader
        # (batch_size is set to 1) is redundant
        if batch['input_ids'].ndim == 3:
            batch = {k: v.squeeze(dim=0) for k, v in batch.items()}

        # Apply Transformer
        outputs = self.model(
            input_ids=batch['input_ids'],
            decoder_input_ids=batch['decoder_input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'],
        )
        train_loss = outputs.loss

        lr = self._optimizer.param_groups[0]['lr']
        self.log('train_loss', train_loss)
        self.log('lr', lr, prog_bar=True)
        return {'loss': train_loss, 'lr': lr}

    # Validation and Testing
    def validation_step(self, batch, batch_idx, dataloader_idx=0, split="val"):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        # Apply Transformer
        outputs = self.model(
            input_ids=batch['input_ids'],
            decoder_input_ids=batch['decoder_input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'],
        )
        val_loss = outputs.loss

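        # batch['labels'] marks padded positions with -100 (ignored by the loss);
        # the arithmetic below maps those -100 entries to token id 0 so that
        # batch_decode can run without hitting invalid ids.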
        labels_str = self._tokenizer.batch_decode(
            torch.ones_like(batch['labels']) *
            ((batch['labels'] == -100) * 100) + batch['labels'],
            skip_special_tokens=True,
        )
        generated_texts, _, _ = self._generate_predictions(
            input_ids=batch['input_ids'], model_max_len=self.max_sequence_len)

        input_centers = self._tokenizer.batch_decode(batch['input_center'],
                                                     skip_special_tokens=True)
        direction = [x[0].item() for x in batch['direction']]
        direction_str = [constants.DIRECTIONS_ID_TO_NAME[x] for x in direction]
        # apply post_processing
        generated_texts = self.postprocess_output_spans(
            input_centers, generated_texts, direction_str)
        results = defaultdict(int)
        for idx, class_id in enumerate(batch['semiotic_class_id']):
            direction = constants.TASK_ID_TO_MODE[batch['direction'][idx]
                                                  [0].item()]
            class_name = self._val_id_to_class[dataloader_idx][
                class_id[0].item()]
            results[f"correct_{class_name}_{direction}"] += torch.tensor(
                labels_str[idx] == generated_texts[idx],
                dtype=torch.int).to(self.device)
            results[f"total_{class_name}_{direction}"] += torch.tensor(1).to(
                self.device)

        results[f"{split}_loss"] = val_loss
        return dict(results)

    def multi_validation_epoch_end(self,
                                   outputs: List,
                                   dataloader_idx=0,
                                   split="val"):
        """
        Called at the end of validation to aggregate outputs.

        Args:
            outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x[f'{split}_loss'] for x in outputs]).mean()

        # create a dictionary to store all the results
        results = {}
        directions = [constants.TN_MODE, constants.ITN_MODE
                      ] if self.mode == constants.JOINT_MODE else [self.mode]
        for class_name in self._val_class_to_id[dataloader_idx]:
            for direction in directions:
                results[f"correct_{class_name}_{direction}"] = 0
                results[f"total_{class_name}_{direction}"] = 0

        for key in results:
            count = [x[key] for x in outputs if key in x]
            count = torch.stack(count).sum(
            ) if len(count) > 0 else torch.tensor(0).to(self.device)
            results[key] = count

        all_results = defaultdict(list)

        if torch.distributed.is_initialized():
            world_size = torch.distributed.get_world_size()
            for ind in range(world_size):
                for key, v in results.items():
                    all_results[key].append(torch.empty_like(v))
            for key, v in results.items():
                torch.distributed.all_gather(all_results[key], v)
        else:
            for key, v in results.items():
                all_results[key].append(v)

        if not torch.distributed.is_initialized(
        ) or torch.distributed.get_rank() == 0:
            if split == "test":
                val_name = self._test_names[dataloader_idx].upper()
            else:
                val_name = self._validation_names[dataloader_idx].upper()
            final_results = defaultdict(int)
            for key, v in all_results.items():
                for _v in v:
                    final_results[key] += _v.item()

            accuracies = defaultdict(dict)
            for key, value in final_results.items():
                if "total_" in key:
                    _, class_name, mode = key.split('_')
                    correct = final_results[f"correct_{class_name}_{mode}"]
                    if value == 0:
                        accuracies[mode][class_name] = (0, correct, value)
                    else:
                        acc = round(correct / value * 100, 3)
                        accuracies[mode][class_name] = (acc, correct, value)

            for mode, values in accuracies.items():
                report = f"Accuracy {mode.upper()} task {val_name}:\n"
                report += '\n'.join([
                    get_formatted_string(
                        (class_name, f'{v[0]}% ({v[1]}/{v[2]})'),
                        str_max_len=24) for class_name, v in values.items()
                ])
                # calculate average across all classes
                all_total = 0
                all_correct = 0
                for _, class_values in values.items():
                    _, correct, total = class_values
                    all_correct += correct
                    all_total += total
                all_acc = round(
                    (all_correct / all_total) * 100, 3) if all_total > 0 else 0
                report += '\n' + get_formatted_string(
                    ('AVG', f'{all_acc}% ({all_correct}/{all_total})'),
                    str_max_len=24)
                logging.info(report)
                accuracies[mode]['AVG'] = [all_acc]

        self.log(f'{split}_loss', avg_loss)
        if self.trainer.is_global_zero:
            for mode in accuracies:
                for class_name, values in accuracies[mode].items():
                    self.log(
                        f'{val_name}_{mode.upper()}_acc_{class_name.upper()}',
                        values[0],
                        rank_zero_only=True)
        return {
            f'{split}_loss': avg_loss,
        }

    def test_step(self, batch, batch_idx, dataloader_idx: int = 0):
        """
        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.
        """
        return self.validation_step(batch,
                                    batch_idx,
                                    dataloader_idx,
                                    split="test")

    def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
        """
        Called at the end of test to aggregate outputs.
        outputs: list of individual outputs of each test step.
        """
        return self.multi_validation_epoch_end(outputs,
                                               dataloader_idx,
                                               split="test")

    @torch.no_grad()
    def _generate_predictions(self,
                              input_ids: torch.Tensor,
                              model_max_len: int = 512):
        """
        Generates predictions
        """
        outputs = self.model.generate(input_ids,
                                      output_scores=True,
                                      return_dict_in_generate=True,
                                      max_length=model_max_len)

        generated_ids, sequence_toks_scores = outputs['sequences'], outputs[
            'scores']
        generated_texts = self._tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True)

        return generated_texts, generated_ids, sequence_toks_scores

    # Functions for inference
    @torch.no_grad()
    def _infer(
        self,
        sents: List[List[str]],
        nb_spans: List[int],
        span_starts: List[List[int]],
        span_ends: List[List[int]],
        inst_directions: List[str],
    ):
        """ Main function for Inference
        Args:
            sents: A list of inputs tokenized by a basic tokenizer.
            nb_spans: A list of ints where each int indicates the number of semiotic spans in each input.
            span_starts: A list of lists where each list contains the starting locations of semiotic spans in an input.
            span_ends: A list of lists where each list contains the ending locations of semiotic spans in an input.
            inst_directions: A list of str where each str indicates the direction of the corresponding instance (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN).

        Returns: A list of lists where each list contains the decoded spans for the corresponding input.
        """
        self.eval()

        if sum(nb_spans) == 0:
            return [[]] * len(sents)
        model, tokenizer = self.model, self._tokenizer
        try:
            model_max_len = model.config.n_positions
        except AttributeError:
            model_max_len = 512
        ctx_size = constants.DECODE_CTX_SIZE
        extra_id_0 = constants.EXTRA_ID_0
        extra_id_1 = constants.EXTRA_ID_1
        """
        Build all_inputs - extracted spans to be transformed by the decoder model
        Inputs for TN direction have "0" prefix, while the backward, ITN direction, has prefix "1"
        "input_centers" - List[str] - ground-truth labels for the span #TODO: rename
        """
        input_centers, input_dirs, all_inputs = [], [], []
        for ix, sent in enumerate(sents):
            cur_inputs = []
            for jx in range(nb_spans[ix]):
                cur_start = span_starts[ix][jx]
                cur_end = span_ends[ix][jx]
                ctx_left = sent[max(0, cur_start - ctx_size):cur_start]
                ctx_right = sent[cur_end + 1:cur_end + 1 + ctx_size]
                span_words = sent[cur_start:cur_end + 1]
                span_words_str = ' '.join(span_words)
                if is_url(span_words_str):
                    span_words_str = span_words_str.lower()
                input_centers.append(span_words_str)
                input_dirs.append(inst_directions[ix])
                # Build cur_inputs
                if inst_directions[ix] == constants.INST_BACKWARD:
                    cur_inputs = [constants.ITN_PREFIX]
                if inst_directions[ix] == constants.INST_FORWARD:
                    cur_inputs = [constants.TN_PREFIX]
                cur_inputs += ctx_left
                cur_inputs += [extra_id_0
                               ] + span_words_str.split(' ') + [extra_id_1]
                cur_inputs += ctx_right
                all_inputs.append(' '.join(cur_inputs))

        # Apply the decoding model
        batch = tokenizer(all_inputs, padding=True, return_tensors='pt')
        input_ids = batch['input_ids'].to(self.device)

        generated_texts, generated_ids, sequence_toks_scores = self._generate_predictions(
            input_ids=input_ids, model_max_len=model_max_len)

        # Use covering grammars (if enabled)
        if self.use_cg:
            # Compute sequence probabilities
            sequence_probs = torch.ones(len(all_inputs)).to(self.device)
            for ix, cur_toks_scores in enumerate(sequence_toks_scores):
                cur_generated_ids = generated_ids[:, ix + 1].tolist()
                cur_toks_probs = torch.nn.functional.softmax(cur_toks_scores,
                                                             dim=-1)
                # Compute selected_toks_probs
                selected_toks_probs = []
                for jx, _id in enumerate(cur_generated_ids):
                    if _id != self._tokenizer.pad_token_id:
                        selected_toks_probs.append(cur_toks_probs[jx, _id])
                    else:
                        selected_toks_probs.append(1)
                selected_toks_probs = torch.tensor(selected_toks_probs).to(
                    self.device)
                sequence_probs *= selected_toks_probs

            # For TN cases where the neural model is not confident, use CGs
            neural_confidence_threshold = self.neural_confidence_threshold
            for ix, (_dir, _input, _prob) in enumerate(
                    zip(input_dirs, input_centers, sequence_probs)):
                if _dir == constants.INST_FORWARD and _prob < neural_confidence_threshold:
                    if is_url(_input):
                        _input = _input.replace(' ',
                                                '')  # Remove spaces in URLs
                    try:
                        cg_outputs = self.cg_normalizer.normalize(
                            text=_input, verbose=False, n_tagged=self.n_tagged)
                        generated_texts[ix] = list(cg_outputs)[0]
                    except Exception:  # if there is any exception, fall back to the input
                        generated_texts[ix] = _input

        # Post processing
        generated_texts = self.postprocess_output_spans(
            input_centers, generated_texts, input_dirs)

        # Prepare final_texts
        final_texts, span_ctx = [], 0
        for nb_span in nb_spans:
            cur_texts = []
            for i in range(nb_span):
                cur_texts.append(generated_texts[span_ctx])
                span_ctx += 1
            final_texts.append(cur_texts)

        return final_texts

    def postprocess_output_spans(self, input_centers: List[str],
                                 generated_spans: List[str],
                                 input_dirs: List[str]):
        """
        Post processing of the generated texts

        Args:
            input_centers: Input str (no special tokens or context)
            generated_spans: Generated spans
            input_dirs: task direction: constants.INST_BACKWARD or constants.INST_FORWARD

        Returns:
            Processing texts
        """
        en_greek_writtens = list(constants.EN_GREEK_TO_SPOKEN.keys())
        en_greek_spokens = list(constants.EN_GREEK_TO_SPOKEN.values())
        for ix, (_input,
                 _output) in enumerate(zip(input_centers, generated_spans)):
            if self.lang == constants.ENGLISH:
                # Handle URL
                if is_url(_input):
                    _output = _output.replace('http', ' h t t p ')
                    _output = _output.replace('/', ' slash ')
                    _output = _output.replace('.', ' dot ')
                    _output = _output.replace(':', ' colon ')
                    _output = _output.replace('-', ' dash ')
                    _output = _output.replace('_', ' underscore ')
                    _output = _output.replace('%', ' percent ')
                    _output = _output.replace('www', ' w w w ')
                    _output = _output.replace('ftp', ' f t p ')
                    generated_spans[ix] = ' '.join(wordninja.split(_output))
                    continue
                # Greek letters
                if _input in en_greek_writtens:
                    if input_dirs[ix] == constants.INST_FORWARD:
                        generated_spans[ix] = constants.EN_GREEK_TO_SPOKEN[
                            _input]
                if _input in en_greek_spokens:
                    if input_dirs[ix] == constants.INST_FORWARD:
                        generated_spans[ix] = _input
                    if input_dirs[ix] == constants.INST_BACKWARD:
                        generated_spans[ix] = constants.EN_SPOKEN_TO_GREEK[
                            _input]
        return generated_spans

    # Functions for processing data
    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        if not train_data_config or not train_data_config.data_path:
            logging.info(
                "Dataloader config or data_path for the train set is missing, so no train data loader is created!"
            )
            self.train_dataset, self._train_dl = None, None
            return
        self.train_dataset, self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config, data_split="train")

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        if not val_data_config or not val_data_config.data_path:
            logging.info(
                "Dataloader config or data_path for the validation set is missing, so no validation data loader is created!"
            )
            self.validation_dataset, self._validation_dl = None, None
            return
        self.validation_dataset, self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config, data_split="val")

    def setup_multiple_validation_data(self,
                                       val_data_config: Union[DictConfig,
                                                              Dict] = None):
        if val_data_config is None:
            val_data_config = self._cfg.validation_ds
        return super().setup_multiple_validation_data(val_data_config)

    def setup_multiple_test_data(self,
                                 test_data_config: Union[DictConfig,
                                                         Dict] = None):
        if test_data_config is None:
            test_data_config = self._cfg.test_ds
        return super().setup_multiple_test_data(test_data_config)

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        if not test_data_config or test_data_config.data_path is None:
            logging.info(
                "Dataloader config or data_path for the test set is missing, so no test data loader is created!"
            )
            self.test_dataset, self._test_dl = None, None
            return
        self.test_dataset, self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config, data_split="test")

    def _setup_dataloader_from_config(self, cfg: DictConfig, data_split: str):
        logging.info(f"Creating {data_split} dataset")

        shuffle = cfg["shuffle"]

        if cfg.get("use_tarred_dataset", False):
            logging.info('Tarred dataset')
            metadata_file = cfg["tar_metadata_file"]
            if metadata_file is None or not os.path.exists(metadata_file):
                raise FileNotFoundError(
                    f"Trying to use tarred dataset but could not find {metadata_file}."
                )

            with open(metadata_file, "r") as f:
                metadata = json.load(f)
                num_batches = metadata["num_batches"]
                tar_files = os.path.join(os.path.dirname(metadata_file),
                                         metadata["text_tar_filepaths"])
            logging.info(f"Loading {tar_files}")

            dataset = TarredTextNormalizationDecoderDataset(
                text_tar_filepaths=tar_files,
                num_batches=num_batches,
                shuffle_n=cfg.get("tar_shuffle_n", 4 *
                                  cfg['batch_size']) if shuffle else 0,
                shard_strategy=cfg.get("shard_strategy", "scatter"),
                global_rank=self.global_rank,
                world_size=self.world_size,
            )

            dl = torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=1,
                sampler=None,
                num_workers=cfg.get("num_workers", 2),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
            )
        else:
            input_file = cfg.data_path
            if not os.path.exists(input_file):
                raise ValueError(f"{input_file} not found.")

            dataset = TextNormalizationDecoderDataset(
                input_file=input_file,
                tokenizer=self._tokenizer,
                tokenizer_name=self.transformer_name,
                mode=self.mode,
                max_len=self.max_sequence_len,
                decoder_data_augmentation=cfg.get('decoder_data_augmentation',
                                                  False)
                if data_split == "train" else False,
                lang=self.lang,
                do_basic_tokenize=cfg.do_basic_tokenize,
                use_cache=cfg.get('use_cache', False),
                max_insts=cfg.get('max_insts', -1),
                do_tokenize=True,
            )

            # create and save class names to class_ids mapping for validation
            # (each validation set might have different classes)
            if data_split in ['val', 'test']:
                if not hasattr(self, "_val_class_to_id"):
                    self._val_class_to_id = []
                    self._val_id_to_class = []
                self._val_class_to_id.append(dataset.label_ids_semiotic)
                self._val_id_to_class.append(
                    {v: k
                     for k, v in dataset.label_ids_semiotic.items()})

            data_collator = DataCollatorForSeq2Seq(
                self._tokenizer,
                model=self.model,
                label_pad_token_id=constants.LABEL_PAD_TOKEN_ID,
                padding=True)
            dl = torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=cfg.batch_size,
                shuffle=shuffle,
                collate_fn=data_collator,
                num_workers=cfg.get("num_workers", 3),
                pin_memory=cfg.get("pin_memory", False),
                drop_last=cfg.get("drop_last", False),
            )

        return dataset, dl

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        result = []
        return result
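
The URL branch of `postprocess_output_spans` can be exercised on its own; a hedged, simplified demo of the same replacements followed by the `wordninja` re-segmentation (the output shown is illustrative):

import wordninja

url = 'https://nvidia.com'
out = url.replace('http', ' h t t p ').replace('/', ' slash ').replace(':', ' colon ').replace('.', ' dot ')
print(' '.join(wordninja.split(out)))  # e.g. 'h t t p s colon slash slash nvidia dot com'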
Example #15
class TestDate:
    inverse_normalizer_en = (InverseNormalizer(
        lang='en', cache_dir=CACHE_DIR, overwrite_cache=False)
                             if PYNINI_AVAILABLE else None)

    @parameterized.expand(
        parse_test_case_file(
            'en/data_inverse_text_normalization/test_cases_date.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_denorm(self, test_input, expected):
        pred = self.inverse_normalizer_en.inverse_normalize(test_input,
                                                            verbose=False)
        assert pred == expected

    normalizer_en = (Normalizer(input_case='cased',
                                lang='en',
                                cache_dir=CACHE_DIR,
                                overwrite_cache=False,
                                post_process=True)
                     if PYNINI_AVAILABLE else None)
    normalizer_with_audio_en = (NormalizerWithAudio(input_case='cased',
                                                    lang='en',
                                                    cache_dir=CACHE_DIR,
                                                    overwrite_cache=False)
                                if PYNINI_AVAILABLE and RUN_AUDIO_BASED_TESTS
                                else None)

    @parameterized.expand(
        parse_test_case_file('en/data_text_normalization/test_cases_date.txt'))
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm_uncased(self, test_input, expected):
        pred = self.normalizer_en.normalize(test_input, verbose=False)
        assert pred == expected

        if self.normalizer_with_audio_en:
            pred_non_deterministic = self.normalizer_with_audio_en.normalize(
                test_input, punct_post_process=False, n_tagged=100)
            assert expected in pred_non_deterministic, f"INPUT: {test_input}"

    normalizer_uppercased = (Normalizer(input_case='cased',
                                        lang='en',
                                        cache_dir=CACHE_DIR,
                                        overwrite_cache=False)
                             if PYNINI_AVAILABLE else None)
    cases_uppercased = {
        "Aug. 8": "august eighth",
        "8 Aug.": "the eighth of august",
        "aug. 8": "august eighth"
    }

    @parameterized.expand(cases_uppercased.items())
    @pytest.mark.skipif(
        not PYNINI_AVAILABLE,
        reason=
        "`pynini` not installed, please install via nemo_text_processing/setup.sh"
    )
    @pytest.mark.run_only_on('CPU')
    @pytest.mark.unit
    def test_norm_cased(self, test_input, expected):
        pred = self.normalizer_uppercased.normalize(test_input, verbose=False)
        assert pred == expected

        if self.normalizer_with_audio_en:
            pred_non_deterministic = self.normalizer_with_audio_en.normalize(
                test_input, punct_post_process=False, n_tagged=30)
            assert expected in pred_non_deterministic
Example #16
class DuplexDecoderModel(NLPModel):
    """
    Transformer-based (duplex) decoder model for TN/ITN.
    """
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        self._tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer)
        super().__init__(cfg=cfg, trainer=trainer)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(cfg.transformer)
        self.transformer_name = cfg.transformer

        # Language
        self.lang = cfg.get('lang', None)

        # Covering Grammars
        self.cg_normalizer = None  # Default
        # We only support integrating with English TN covering grammars at the moment
        self.use_cg = cfg.get('use_cg',
                              False) and self.lang == constants.ENGLISH
        if self.use_cg:
            self.setup_cgs(cfg)

    # Setup covering grammars (if enabled)
    def setup_cgs(self, cfg: DictConfig):
        """
        Setup covering grammars (if enabled).
        :param cfg: Configs of the decoder model.
        """
        self.use_cg = True
        self.neural_confidence_threshold = cfg.get(
            'neural_confidence_threshold', 0.99)
        self.n_tagged = cfg.get('n_tagged', 1)
        input_case = 'cased'  # input_case is cased by default
        if hasattr(self._tokenizer,
                   'do_lower_case') and self._tokenizer.do_lower_case:
            input_case = 'lower_cased'
        if not PYNINI_AVAILABLE:
            raise Exception(
                "`pynini` is not installed ! \n"
                "Please run the `nemo_text_processing/setup.sh` script"
                "prior to usage of this toolkit.")
        self.cg_normalizer = NormalizerWithAudio(input_case=input_case,
                                                 lang=self.lang)

    # Training
    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        # Apply Transformer
        outputs = self.model(
            input_ids=batch['input_ids'],
            decoder_input_ids=batch['decoder_input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'],
        )
        train_loss = outputs.loss

        lr = self._optimizer.param_groups[0]['lr']
        self.log('train_loss', train_loss)
        self.log('lr', lr, prog_bar=True)
        return {'loss': train_loss, 'lr': lr}

    # Validation and Testing
    def validation_step(self, batch, batch_idx):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """

        # Apply Transformer
        outputs = self.model(
            input_ids=batch['input_ids'],
            decoder_input_ids=batch['decoder_input_ids'],
            attention_mask=batch['attention_mask'],
            labels=batch['labels'],
        )
        val_loss = outputs.loss

        return {'val_loss': val_loss}

    def validation_epoch_end(self, outputs):
        """
        Called at the end of validation to aggregate outputs.
        :param outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        self.log('val_loss', avg_loss)

        return {
            'val_loss': avg_loss,
        }

    def test_step(self, batch, batch_idx):
        """
        Lightning calls this inside the test loop with the data from the test dataloader
        passed in as `batch`.
        """
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        """
        Called at the end of test to aggregate outputs.
        :param outputs: list of individual outputs of each test step.
        """
        return self.validation_epoch_end(outputs)

    # Functions for inference
    @torch.no_grad()
    def _infer(
        self,
        sents: List[List[str]],
        nb_spans: List[int],
        span_starts: List[List[int]],
        span_ends: List[List[int]],
        inst_directions: List[str],
    ):
        """ Main function for Inference
        Args:
            sents: A list of inputs tokenized by a basic tokenizer.
            nb_spans: A list of ints where each int indicates the number of semiotic spans in each input.
            span_starts: A list of lists where each list contains the starting locations of semiotic spans in an input.
            span_ends: A list of lists where each list contains the ending locations of semiotic spans in an input.
            inst_directions: A list of str where each str indicates the direction of the corresponding instance (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN).

        Returns: A list of lists where each list contains the decoded spans for the corresponding input.
        """
        self.eval()

        if sum(nb_spans) == 0:
            return [[]] * len(sents)
        model, tokenizer = self.model, self._tokenizer
        try:
            model_max_len = model.config.n_positions
        except AttributeError:
            model_max_len = 512
        ctx_size = constants.DECODE_CTX_SIZE
        extra_id_0 = constants.EXTRA_ID_0
        extra_id_1 = constants.EXTRA_ID_1

        # Build all_inputs
        input_centers, input_dirs, all_inputs = [], [], []
        for ix, sent in enumerate(sents):
            cur_inputs = []
            for jx in range(nb_spans[ix]):
                cur_start = span_starts[ix][jx]
                cur_end = span_ends[ix][jx]
                ctx_left = sent[max(0, cur_start - ctx_size):cur_start]
                ctx_right = sent[cur_end + 1:cur_end + 1 + ctx_size]
                span_words = sent[cur_start:cur_end + 1]
                span_words_str = ' '.join(span_words)
                if is_url(span_words_str):
                    span_words_str = span_words_str.lower()
                input_centers.append(span_words_str)
                input_dirs.append(inst_directions[ix])
                # Build cur_inputs
                if inst_directions[ix] == constants.INST_BACKWARD:
                    cur_inputs = [constants.ITN_PREFIX]
                if inst_directions[ix] == constants.INST_FORWARD:
                    cur_inputs = [constants.TN_PREFIX]
                cur_inputs += ctx_left
                cur_inputs += [extra_id_0
                               ] + span_words_str.split(' ') + [extra_id_1]
                cur_inputs += ctx_right
                all_inputs.append(' '.join(cur_inputs))

        # Apply the decoding model
        batch = tokenizer(all_inputs, padding=True, return_tensors='pt')
        input_ids = batch['input_ids'].to(self.device)
        outputs = model.generate(input_ids,
                                 output_scores=True,
                                 return_dict_in_generate=True,
                                 max_length=model_max_len)
        generated_ids = outputs['sequences']
        sequence_toks_scores = outputs['scores']
        generated_texts = tokenizer.batch_decode(generated_ids,
                                                 skip_special_tokens=True)

        # Use covering grammars (if enabled)
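        # (CGs are WFST-based covering grammars; `self.cg_normalizer` is
        # assumed to be a pynini-backed Normalizer configured elsewhere in
        # this class.)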
        if self.use_cg:
            # Compute sequence probabilities
            sequence_probs = torch.ones(len(all_inputs)).to(self.device)
            for ix, cur_toks_scores in enumerate(sequence_toks_scores):
                cur_generated_ids = generated_ids[:, ix + 1].tolist()
                cur_toks_probs = torch.nn.functional.softmax(cur_toks_scores,
                                                             dim=-1)
                # Compute selected_toks_probs
                selected_toks_probs = []
                for jx, _id in enumerate(cur_generated_ids):
                    if _id != self._tokenizer.pad_token_id:
                        # .item() keeps the list plain floats instead of
                        # 0-dim tensors, so torch.tensor() below is clean
                        selected_toks_probs.append(
                            cur_toks_probs[jx, _id].item())
                    else:
                        # Padding contributes a neutral factor to the product
                        selected_toks_probs.append(1.0)
                selected_toks_probs = torch.tensor(selected_toks_probs).to(
                    self.device)
                sequence_probs *= selected_toks_probs
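            # sequence_probs[i] now holds the product of per-token
            # probabilities of generated sequence i (pad steps excluded).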

            # For TN cases where the neural model is not confident, use CGs
            neural_confidence_threshold = self.neural_confidence_threshold
            for ix, (_dir, _input, _prob) in enumerate(
                    zip(input_dirs, input_centers, sequence_probs)):
                if _dir == constants.INST_FORWARD and _prob < neural_confidence_threshold:
                    if is_url(_input):
                        # Remove spaces in URLs
                        _input = _input.replace(' ', '')
                    try:
                        cg_outputs = self.cg_normalizer.normalize(
                            text=_input, verbose=False, n_tagged=self.n_tagged)
                        generated_texts[ix] = list(cg_outputs)[0]
                    except Exception:  # if there is any exception, fall back to the input
                        generated_texts[ix] = _input

        # Post processing
        generated_texts = self.postprocess_output_spans(
            input_centers, generated_texts, input_dirs)

        # Prepare final_texts
        final_texts, span_ctx = [], 0
        for nb_span in nb_spans:
            cur_texts = []
            for _ in range(nb_span):
                cur_texts.append(generated_texts[span_ctx])
                span_ctx += 1
            final_texts.append(cur_texts)

        return final_texts
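
    # A minimal usage sketch for `_infer` (hypothetical values throughout;
    # `model` stands for a trained instance of this class, and the decoded
    # output depends entirely on the checkpoint):
    #
    #   sents = [['the', 'show', 'starts', 'at', '5:30', 'pm']]
    #   nb_spans = [1]                    # one semiotic span in this sentence
    #   span_starts = [[4]]               # span covers tokens 4..5: '5:30 pm'
    #   span_ends = [[5]]
    #   inst_directions = [constants.INST_FORWARD]   # TN: written -> spoken
    #   decoded = model._infer(sents, nb_spans, span_starts, span_ends,
    #                          inst_directions)
    #   # decoded -> e.g. [['five thirty p m']]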

    def postprocess_output_spans(self, input_centers, output_spans,
                                 input_dirs):
        """Applies rule-based post-processing (URL verbalization and Greek-letter mapping) to the decoded spans."""
        en_greek_writtens = list(constants.EN_GREEK_TO_SPOKEN.keys())
        en_greek_spokens = list(constants.EN_GREEK_TO_SPOKEN.values())
        for ix, (_input, _output) in enumerate(zip(input_centers,
                                                   output_spans)):
            if self.lang == constants.ENGLISH:
                # Handle URL
                if is_url(_input):
                    _output = _output.replace('http', ' h t t p ')
                    _output = _output.replace('/', ' slash ')
                    _output = _output.replace('.', ' dot ')
                    _output = _output.replace(':', ' colon ')
                    _output = _output.replace('-', ' dash ')
                    _output = _output.replace('_', ' underscore ')
                    _output = _output.replace('%', ' percent ')
                    _output = _output.replace('www', ' w w w ')
                    _output = _output.replace('ftp', ' f t p ')
                    output_spans[ix] = ' '.join(wordninja.split(_output))
                    continue
                # Greek letters
                if _input in en_greek_writtens:
                    if input_dirs[ix] == constants.INST_FORWARD:
                        output_spans[ix] = constants.EN_GREEK_TO_SPOKEN[_input]
                if _input in en_greek_spokens:
                    if input_dirs[ix] == constants.INST_FORWARD:
                        output_spans[ix] = _input
                    if input_dirs[ix] == constants.INST_BACKWARD:
                        output_spans[ix] = constants.EN_SPOKEN_TO_GREEK[_input]
        return output_spans
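
    # Illustrative trace of the English URL branch above (hypothetical values;
    # the final segmentation depends on wordninja's word list):
    #
    #   _output = 'nvidia.com/cuda'
    #   after the character replacements: 'nvidia dot com slash cuda'
    #   wordninja.split then re-splits any words that remain concatenated
    #   (e.g. a hypothetical 'bestbuy' -> 'best buy').
    #
    # For Greek letters, TN (INST_FORWARD) maps a written symbol to its spoken
    # name via EN_GREEK_TO_SPOKEN, and ITN (INST_BACKWARD) maps the spoken
    # name back to the symbol via EN_SPOKEN_TO_GREEK.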

    # Functions for processing data
    def setup_training_data(self, train_data_config: Optional[DictConfig]):
        if not train_data_config or not train_data_config.data_path:
            logging.info(
                "Dataloader config or data_path for the train set is missing, so no train data loader is created!"
            )
            self.train_dataset, self._train_dl = None, None
            return
        self.train_dataset, self._train_dl = self._setup_dataloader_from_config(
            cfg=train_data_config, mode="train")

    def setup_validation_data(self, val_data_config: Optional[DictConfig]):
        if not val_data_config or not val_data_config.data_path:
            logging.info(
                "Dataloader config or data_path for the validation set is missing, so no validation data loader is created!"
            )
            self.validation_dataset, self._validation_dl = None, None
            return
        self.validation_dataset, self._validation_dl = self._setup_dataloader_from_config(
            cfg=val_data_config, mode="val")

    def setup_test_data(self, test_data_config: Optional[DictConfig]):
        if not test_data_config or test_data_config.data_path is None:
            logging.info(
                "Dataloader config or data_path for the test set is missing, so no test data loader is created!"
            )
            self.test_dataset, self._test_dl = None, None
            return
        self.test_dataset, self._test_dl = self._setup_dataloader_from_config(
            cfg=test_data_config, mode="test")

    def _setup_dataloader_from_config(self, cfg: DictConfig, mode: str):
        tokenizer, model = self._tokenizer, self.model
        start_time = perf_counter()
        logging.info(f'Creating {mode} dataset')
        input_file = cfg.data_path
        dataset = TextNormalizationDecoderDataset(
            input_file,
            tokenizer,
            self.transformer_name,
            cfg.mode,
            cfg.get('max_decoder_len', tokenizer.model_max_length),
            cfg.get('decoder_data_augmentation', False),
            cfg.lang,
            cfg.do_basic_tokenize,
            cfg.get('use_cache', False),
            cfg.get('max_insts', -1),
        )
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            model=model,
            label_pad_token_id=constants.LABEL_PAD_TOKEN_ID,
        )
        dl = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=cfg.batch_size,
                                         shuffle=cfg.shuffle,
                                         collate_fn=data_collator)
        running_time = perf_counter() - start_time
        logging.info(f'Took {running_time:.2f} seconds')
        return dataset, dl
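
    # A sketch of the dataloader config this method expects. The field names
    # come from the accesses above; the values are placeholders:
    #
    #   cfg = OmegaConf.create({
    #       'data_path': 'train.tsv',              # becomes input_file
    #       'mode': 'tn',                          # dataset direction/mode
    #       'batch_size': 32,
    #       'shuffle': True,
    #       'lang': 'en',
    #       'do_basic_tokenize': False,
    #       'max_decoder_len': 80,                 # optional
    #       'decoder_data_augmentation': True,     # optional
    #       'use_cache': False,                    # optional
    #       'max_insts': -1,                       # optional, -1 = use all
    #   })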

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained models that can be instantiated directly from NVIDIA's NGC cloud.
        Returns:
            List of available pre-trained models.
        """
        result = []
        return result