class TestCardinal: inverse_normalizer_en = InverseNormalizer(lang='en') if PYNINI_AVAILABLE else None @parameterized.expand(parse_test_case_file('en/data_inverse_text_normalization/test_cases_cardinal.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason="`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_denorm(self, test_input, expected): pred = self.inverse_normalizer_en.inverse_normalize(test_input, verbose=False) assert pred == expected normalizer_en = Normalizer(input_case='cased', lang='en') if PYNINI_AVAILABLE else None normalizer_with_audio_en = NormalizerWithAudio(input_case='cased', lang='en') if PYNINI_AVAILABLE else None @parameterized.expand(parse_test_case_file('en/data_text_normalization/test_cases_cardinal.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason="`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm(self, test_input, expected): pred = self.normalizer_en.normalize(test_input, verbose=False) assert pred == expected pred_non_deterministic = self.normalizer_with_audio_en.normalize(test_input, n_tagged=100) assert expected in pred_non_deterministic
class TestBoundary: normalizer_en = (Normalizer(input_case='cased', lang='en', cache_dir=CACHE_DIR, overwrite_cache=False) if PYNINI_AVAILABLE else None) normalizer_with_audio_en = (NormalizerWithAudio(input_case='cased', lang='en', cache_dir=CACHE_DIR, overwrite_cache=False) if PYNINI_AVAILABLE and CACHE_DIR else None) @parameterized.expand( parse_test_case_file( 'en/data_text_normalization/test_cases_boundary.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm(self, test_input, expected): pred = self.normalizer_en.normalize(test_input, verbose=False) assert pred == expected if self.normalizer_with_audio_en: pred_non_deterministic = self.normalizer_with_audio_en.normalize( test_input, n_tagged=30, punct_post_process=False) assert expected in pred_non_deterministic
class TestBoundary: normalizer = Normalizer(input_case='cased') if PYNINI_AVAILABLE else None normalizer_with_audio = NormalizerWithAudio( input_case='cased') if PYNINI_AVAILABLE else None @parameterized.expand( parse_test_case_file('data_text_normalization/test_cases_boundary.txt') ) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm(self, test_input, expected): pred = self.normalizer.normalize(test_input, verbose=False) assert pred == expected pred_non_deterministic = self.normalizer_with_audio.normalize( test_input, n_tagged=100, punct_pre_process=False, punct_post_process=False) assert expected in pred_non_deterministic
class TestDate: inverse_normalizer = InverseNormalizer() if PYNINI_AVAILABLE else None @parameterized.expand( parse_test_case_file( 'data_inverse_text_normalization/test_cases_date.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_denorm(self, test_input, expected): pred = self.inverse_normalizer.inverse_normalize(test_input, verbose=False) assert pred == expected normalizer = Normalizer(input_case='cased') if PYNINI_AVAILABLE else None normalizer_with_audio = NormalizerWithAudio( input_case='cased') if PYNINI_AVAILABLE else None @parameterized.expand( parse_test_case_file('data_text_normalization/test_cases_date.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm_uncased(self, test_input, expected): pred = self.normalizer.normalize(test_input, verbose=False) assert pred == expected pred_non_deterministic = self.normalizer_with_audio.normalize( test_input, n_tagged=100) assert expected in pred_non_deterministic normalizer_uppercased = Normalizer( input_case='cased') if PYNINI_AVAILABLE else None cases_uppercased = { "Aug. 8": "august eighth", "8 Aug.": "the eighth of august", "aug. 8": "august eighth" } @parameterized.expand(cases_uppercased.items()) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm_cased(self, test_input, expected): pred = self.normalizer_uppercased.normalize(test_input, verbose=False) assert pred == expected pred_non_deterministic = self.normalizer_with_audio.normalize( test_input, n_tagged=100) assert expected in pred_non_deterministic
class TestSpecialText: normalizer_en = (Normalizer(input_case='cased', lang='en', cache_dir=CACHE_DIR, overwrite_cache=False) if PYNINI_AVAILABLE else None) normalizer_with_audio_en = (NormalizerWithAudio(input_case='cased', lang='en', cache_dir=CACHE_DIR, overwrite_cache=False) if PYNINI_AVAILABLE and RUN_AUDIO_BASED_TESTS else None) @parameterized.expand( parse_test_case_file( 'en/data_text_normalization/test_cases_special_text.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm(self, test_input, expected): pred = self.normalizer_en.normalize(test_input, verbose=False) assert pred == expected
class TestRoman: normalizer_en = (Normalizer(input_case='cased', lang='en', cache_dir=CACHE_DIR, overwrite_cache=False) if PYNINI_AVAILABLE else None) normalizer_with_audio_en = (NormalizerWithAudio(input_case='cased', lang='en', cache_dir=CACHE_DIR, overwrite_cache=False) if PYNINI_AVAILABLE and RUN_AUDIO_BASED_TESTS else None) # address is tagged by the measure class @parameterized.expand( parse_test_case_file('en/data_text_normalization/test_cases_roman.txt') ) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm(self, test_input, expected): # pred = self.normalizer_en.normalize(test_input, verbose=False) # assert pred == expected # # if self.normalizer_with_audio_en: # pred_non_deterministic = self.normalizer_with_audio_en.normalize( # test_input, n_tagged=30, punct_post_process=False, # ) # assert expected in pred_non_deterministic pass
class TestNormalizeWithAudio: normalizer_de = (NormalizerWithAudio(input_case='cased', lang='de', cache_dir=CACHE_DIR, overwrite_cache=False) if PYNINI_AVAILABLE else None) @parameterized.expand( get_test_cases_multiple( 'de/data_text_normalization/test_cases_normalize_with_audio.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm(self, test_input, expected): pred = self.normalizer_de.normalize(test_input, n_tagged=1000, punct_post_process=False) print(expected) print("pred") print(pred) assert len(set(pred).intersection(set(expected))) == len( expected), f'missing: {set(expected).difference(set(pred))}'
def setup_cgs(self, cfg: DictConfig): """ Setup covering grammars (if enabled). :param cfg: Configs of the decoder model. """ self.use_cg = True self.neural_confidence_threshold = cfg.get('neural_confidence_threshold', 0.99) self.n_tagged = cfg.get('n_tagged', 1) input_case = 'cased' # input_case is cased by default if hasattr(self._tokenizer, 'do_lower_case') and self._tokenizer.do_lower_case: input_case = 'lower_cased' if not PYNINI_AVAILABLE: raise Exception( "`pynini` is not installed ! \n" "Please run the `nemo_text_processing/setup.sh` script" "prior to usage of this toolkit." ) self.cg_normalizer = NormalizerWithAudio(input_case=input_case, lang=self.lang)
class TestNormalizeWithAudio: normalizer = NormalizerWithAudio(input_case='cased') if PYNINI_AVAILABLE else None @parameterized.expand(get_test_cases_multiple('data_text_normalization/test_cases_normalize_with_audio.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason="`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm(self, test_input, expected): pred = self.normalizer.normalize(test_input, n_tagged=700) assert len(set(pred).intersection(set(expected))) == len(expected), f'pred: {pred}'
class TestWord: inverse_normalizer_en = (InverseNormalizer( lang='en', cache_dir=CACHE_DIR, overwrite_cache=False) if PYNINI_AVAILABLE else None) @parameterized.expand( parse_test_case_file( 'en/data_inverse_text_normalization/test_cases_word.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_denorm(self, test_input, expected): pred = self.inverse_normalizer_en.inverse_normalize(test_input, verbose=False) assert pred == expected normalizer_en = (Normalizer(input_case='cased', lang='en', cache_dir=CACHE_DIR, overwrite_cache=False, post_process=True) if PYNINI_AVAILABLE else None) normalizer_with_audio_en = (NormalizerWithAudio(input_case='cased', lang='en', cache_dir=CACHE_DIR, overwrite_cache=False) if PYNINI_AVAILABLE and RUN_AUDIO_BASED_TESTS else None) @parameterized.expand( parse_test_case_file('en/data_text_normalization/test_cases_word.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm(self, test_input, expected): pred = self.normalizer_en.normalize(test_input, verbose=False) assert pred == expected, f"input: {test_input} != {expected}" if self.normalizer_with_audio_en: pred_non_deterministic = self.normalizer_with_audio_en.normalize( test_input, n_tagged=3, punct_post_process=False) assert expected in pred_non_deterministic, f"input: {test_input}"
class TestWhitelist: inverse_normalizer_en = InverseNormalizer(lang='en') if PYNINI_AVAILABLE else None @parameterized.expand(parse_test_case_file('en/data_inverse_text_normalization/test_cases_whitelist.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason="`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_denorm(self, test_input, expected): pred = self.inverse_normalizer_en.inverse_normalize(test_input, verbose=False) assert pred == expected normalizer_en = Normalizer(input_case='lower_cased') if PYNINI_AVAILABLE else None normalizer_with_audio_en = NormalizerWithAudio(input_case='cased', lang='en') if PYNINI_AVAILABLE else None @parameterized.expand(parse_test_case_file('en/data_text_normalization/test_cases_whitelist.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason="`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm(self, test_input, expected): pred = self.normalizer_en.normalize(test_input, verbose=False) assert pred == expected pred_non_deterministic = self.normalizer_with_audio_en.normalize(test_input, n_tagged=100) assert expected in pred_non_deterministic normalizer_uppercased = Normalizer(input_case='cased', lang='en') if PYNINI_AVAILABLE else None cases_uppercased = {"Dr. Evil": "doctor Evil", "No. 4": "number four", "dr. Evil": "dr. Evil", "no. 4": "no. four"} @parameterized.expand(cases_uppercased.items()) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason="`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm_cased(self, test_input, expected): pred = self.normalizer_uppercased.normalize(test_input, verbose=False) assert pred == expected pred_non_deterministic = self.normalizer_with_audio_en.normalize(test_input, n_tagged=100) assert expected in pred_non_deterministic
class TestRuNormalizeWithAudio: normalizer = NormalizerWithAudio(input_case='cased', lang='ru', cache_dir=CACHE_DIR) if PYNINI_AVAILABLE else None
class TestRuNormalizeWithAudio: normalizer = NormalizerWithAudio( input_case='cased', lang='ru', cache_dir=CACHE_DIR) if PYNINI_AVAILABLE else None N_TAGGED = 3000 @parameterized.expand( parse_test_case_file( 'ru/data_text_normalization/test_cases_cardinal.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm_cardinal(self, expected, test_input): preds = self.normalizer.normalize(test_input, n_tagged=self.N_TAGGED) assert expected in preds @parameterized.expand( parse_test_case_file( 'ru/data_text_normalization/test_cases_ordinal.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm_ordinal(self, expected, test_input): preds = self.normalizer.normalize(test_input, n_tagged=self.N_TAGGED) assert expected in preds @parameterized.expand( parse_test_case_file( 'ru/data_text_normalization/test_cases_decimal.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm_decimal(self, expected, test_input): preds = self.normalizer.normalize(test_input, n_tagged=5000) assert expected in preds @parameterized.expand( parse_test_case_file( 'ru/data_text_normalization/test_cases_measure.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm_measure(self, expected, test_input): preds = self.normalizer.normalize(test_input, n_tagged=self.N_TAGGED) assert expected in preds @parameterized.expand( parse_test_case_file('ru/data_text_normalization/test_cases_date.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm_date(self, expected, test_input): preds = self.normalizer.normalize(test_input, n_tagged=-1) assert expected in preds, expected not in preds @parameterized.expand( parse_test_case_file( 'ru/data_text_normalization/test_cases_telephone.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm_telephone(self, expected, test_input): preds = self.normalizer.normalize(test_input, n_tagged=-1) assert expected in preds, expected not in preds @parameterized.expand( parse_test_case_file('ru/data_text_normalization/test_cases_money.txt') ) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm_money(self, expected, test_input): preds = self.normalizer.normalize(test_input, n_tagged=-1) assert expected in preds, expected not in preds @parameterized.expand( parse_test_case_file('ru/data_text_normalization/test_cases_time.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm_time(self, expected, test_input): preds = self.normalizer.normalize(test_input, n_tagged=-1) assert expected in preds, expected not in preds @parameterized.expand( parse_test_case_file( 'ru/data_text_normalization/test_cases_electronic.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm_electronic(self, expected, test_input): preds = self.normalizer.normalize(test_input, n_tagged=self.N_TAGGED) assert expected in preds @parameterized.expand( parse_test_case_file( 'ru/data_text_normalization/test_cases_whitelist.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm_whitelist(self, expected, test_input): preds = self.normalizer.normalize(test_input, n_tagged=self.N_TAGGED) assert expected in preds @parameterized.expand( parse_test_case_file('ru/data_text_normalization/test_cases_word.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm_word(self, expected, test_input): preds = self.normalizer.normalize(test_input, n_tagged=self.N_TAGGED) assert expected in preds
class DuplexDecoderModel(NLPModel): """ Transformer-based (duplex) decoder model for TN/ITN. """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Get global rank and total number of GPU workers for IterableDataset partitioning, if applicable # Global_rank and local_rank is set by LightningModule in Lightning 1.2.0 self.world_size = 1 if trainer is not None: self.world_size = trainer.num_nodes * trainer.num_gpus self._tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer) super().__init__(cfg=cfg, trainer=trainer) self.model = AutoModelForSeq2SeqLM.from_pretrained(cfg.transformer) self.max_sequence_len = cfg.get('max_sequence_len', self._tokenizer.model_max_length) self.mode = cfg.get('mode', 'joint') self.transformer_name = cfg.transformer # Language self.lang = cfg.get('lang', None) # Covering Grammars self.cg_normalizer = None # Default # We only support integrating with English TN covering grammars at the moment self.use_cg = cfg.get('use_cg', False) and self.lang == constants.ENGLISH if self.use_cg: self.setup_cgs(cfg) # setup processor for detokenization self.processor = MosesProcessor(lang_id=self.lang) # Setup covering grammars (if enabled) def setup_cgs(self, cfg: DictConfig): """ Setup covering grammars (if enabled). :param cfg: Configs of the decoder model. """ self.use_cg = True self.neural_confidence_threshold = cfg.get( 'neural_confidence_threshold', 0.99) self.n_tagged = cfg.get('n_tagged', 1) input_case = 'cased' # input_case is cased by default if hasattr(self._tokenizer, 'do_lower_case') and self._tokenizer.do_lower_case: input_case = 'lower_cased' if not PYNINI_AVAILABLE: raise Exception( "`pynini` is not installed ! \n" "Please run the `nemo_text_processing/setup.sh` script" "prior to usage of this toolkit.") self.cg_normalizer = NormalizerWithAudio(input_case=input_case, lang=self.lang) # Training def training_step(self, batch, batch_idx): """ Lightning calls this inside the training loop with the data from the training dataloader passed in as `batch`. """ # tarred dataset contains batches, and the first dimension of size 1 added by the DataLoader # (batch_size is set to 1) is redundant if batch['input_ids'].ndim == 3: batch = {k: v.squeeze(dim=0) for k, v in batch.items()} # Apply Transformer outputs = self.model( input_ids=batch['input_ids'], decoder_input_ids=batch['decoder_input_ids'], attention_mask=batch['attention_mask'], labels=batch['labels'], ) train_loss = outputs.loss lr = self._optimizer.param_groups[0]['lr'] self.log('train_loss', train_loss) self.log('lr', lr, prog_bar=True) return {'loss': train_loss, 'lr': lr} # Validation and Testing def validation_step(self, batch, batch_idx, dataloader_idx=0, split="val"): """ Lightning calls this inside the validation loop with the data from the validation dataloader passed in as `batch`. """ # Apply Transformer outputs = self.model( input_ids=batch['input_ids'], decoder_input_ids=batch['decoder_input_ids'], attention_mask=batch['attention_mask'], labels=batch['labels'], ) val_loss = outputs.loss labels_str = self._tokenizer.batch_decode( torch.ones_like(batch['labels']) * ((batch['labels'] == -100) * 100) + batch['labels'], skip_special_tokens=True, ) generated_texts, _, _ = self._generate_predictions( input_ids=batch['input_ids'], model_max_len=self.max_sequence_len) input_centers = self._tokenizer.batch_decode(batch['input_center'], skip_special_tokens=True) direction = [x[0].item() for x in batch['direction']] direction_str = [constants.DIRECTIONS_ID_TO_NAME[x] for x in direction] # apply post_processing generated_texts = self.postprocess_output_spans( input_centers, generated_texts, direction_str) results = defaultdict(int) for idx, class_id in enumerate(batch['semiotic_class_id']): direction = constants.TASK_ID_TO_MODE[batch['direction'][idx] [0].item()] class_name = self._val_id_to_class[dataloader_idx][ class_id[0].item()] results[f"correct_{class_name}_{direction}"] += torch.tensor( labels_str[idx] == generated_texts[idx], dtype=torch.int).to(self.device) results[f"total_{class_name}_{direction}"] += torch.tensor(1).to( self.device) results[f"{split}_loss"] = val_loss return dict(results) def multi_validation_epoch_end(self, outputs: List, dataloader_idx=0, split="val"): """ Called at the end of validation to aggregate outputs. Args: outputs: list of individual outputs of each validation step. """ avg_loss = torch.stack([x[f'{split}_loss'] for x in outputs]).mean() # create a dictionary to store all the results results = {} directions = [constants.TN_MODE, constants.ITN_MODE ] if self.mode == constants.JOINT_MODE else [self.mode] for class_name in self._val_class_to_id[dataloader_idx]: for direction in directions: results[f"correct_{class_name}_{direction}"] = 0 results[f"total_{class_name}_{direction}"] = 0 for key in results: count = [x[key] for x in outputs if key in x] count = torch.stack(count).sum( ) if len(count) > 0 else torch.tensor(0).to(self.device) results[key] = count all_results = defaultdict(list) if torch.distributed.is_initialized(): world_size = torch.distributed.get_world_size() for ind in range(world_size): for key, v in results.items(): all_results[key].append(torch.empty_like(v)) for key, v in results.items(): torch.distributed.all_gather(all_results[key], v) else: for key, v in results.items(): all_results[key].append(v) if not torch.distributed.is_initialized( ) or torch.distributed.get_rank() == 0: if split == "test": val_name = self._test_names[dataloader_idx].upper() else: val_name = self._validation_names[dataloader_idx].upper() final_results = defaultdict(int) for key, v in all_results.items(): for _v in v: final_results[key] += _v.item() accuracies = defaultdict(dict) for key, value in final_results.items(): if "total_" in key: _, class_name, mode = key.split('_') correct = final_results[f"correct_{class_name}_{mode}"] if value == 0: accuracies[mode][class_name] = (0, correct, value) else: acc = round(correct / value * 100, 3) accuracies[mode][class_name] = (acc, correct, value) for mode, values in accuracies.items(): report = f"Accuracy {mode.upper()} task {val_name}:\n" report += '\n'.join([ get_formatted_string( (class_name, f'{v[0]}% ({v[1]}/{v[2]})'), str_max_len=24) for class_name, v in values.items() ]) # calculate average across all classes all_total = 0 all_correct = 0 for _, class_values in values.items(): _, correct, total = class_values all_correct += correct all_total += total all_acc = round( (all_correct / all_total) * 100, 3) if all_total > 0 else 0 report += '\n' + get_formatted_string( ('AVG', f'{all_acc}% ({all_correct}/{all_total})'), str_max_len=24) logging.info(report) accuracies[mode]['AVG'] = [all_acc] self.log(f'{split}_loss', avg_loss) if self.trainer.is_global_zero: for mode in accuracies: for class_name, values in accuracies[mode].items(): self.log( f'{val_name}_{mode.upper()}_acc_{class_name.upper()}', values[0], rank_zero_only=True) return { f'{split}_loss': avg_loss, } def test_step(self, batch, batch_idx, dataloader_idx: int = 0): """ Lightning calls this inside the test loop with the data from the test dataloader passed in as `batch`. """ return self.validation_step(batch, batch_idx, dataloader_idx, split="test") def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): """ Called at the end of test to aggregate outputs. outputs: list of individual outputs of each test step. """ return self.multi_validation_epoch_end(outputs, dataloader_idx, split="test") @torch.no_grad() def _generate_predictions(self, input_ids: torch.Tensor, model_max_len: int = 512): """ Generates predictions """ outputs = self.model.generate(input_ids, output_scores=True, return_dict_in_generate=True, max_length=model_max_len) generated_ids, sequence_toks_scores = outputs['sequences'], outputs[ 'scores'] generated_texts = self._tokenizer.batch_decode( generated_ids, skip_special_tokens=True) return generated_texts, generated_ids, sequence_toks_scores # Functions for inference @torch.no_grad() def _infer( self, sents: List[List[str]], nb_spans: List[int], span_starts: List[List[int]], span_ends: List[List[int]], inst_directions: List[str], ): """ Main function for Inference Args: sents: A list of inputs tokenized by a basic tokenizer. nb_spans: A list of ints where each int indicates the number of semiotic spans in each input. span_starts: A list of lists where each list contains the starting locations of semiotic spans in an input. span_ends: A list of lists where each list contains the ending locations of semiotic spans in an input. inst_directions: A list of str where each str indicates the direction of the corresponding instance (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN). Returns: A list of lists where each list contains the decoded spans for the corresponding input. """ self.eval() if sum(nb_spans) == 0: return [[]] * len(sents) model, tokenizer = self.model, self._tokenizer try: model_max_len = model.config.n_positions except AttributeError: model_max_len = 512 ctx_size = constants.DECODE_CTX_SIZE extra_id_0 = constants.EXTRA_ID_0 extra_id_1 = constants.EXTRA_ID_1 """ Build all_inputs - extracted spans to be transformed by the decoder model Inputs for TN direction have "0" prefix, while the backward, ITN direction, has prefix "1" "input_centers" - List[str] - ground-truth labels for the span #TODO: rename """ input_centers, input_dirs, all_inputs = [], [], [] for ix, sent in enumerate(sents): cur_inputs = [] for jx in range(nb_spans[ix]): cur_start = span_starts[ix][jx] cur_end = span_ends[ix][jx] ctx_left = sent[max(0, cur_start - ctx_size):cur_start] ctx_right = sent[cur_end + 1:cur_end + 1 + ctx_size] span_words = sent[cur_start:cur_end + 1] span_words_str = ' '.join(span_words) if is_url(span_words_str): span_words_str = span_words_str.lower() input_centers.append(span_words_str) input_dirs.append(inst_directions[ix]) # Build cur_inputs if inst_directions[ix] == constants.INST_BACKWARD: cur_inputs = [constants.ITN_PREFIX] if inst_directions[ix] == constants.INST_FORWARD: cur_inputs = [constants.TN_PREFIX] cur_inputs += ctx_left cur_inputs += [extra_id_0 ] + span_words_str.split(' ') + [extra_id_1] cur_inputs += ctx_right all_inputs.append(' '.join(cur_inputs)) # Apply the decoding model batch = tokenizer(all_inputs, padding=True, return_tensors='pt') input_ids = batch['input_ids'].to(self.device) generated_texts, generated_ids, sequence_toks_scores = self._generate_predictions( input_ids=input_ids, model_max_len=model_max_len) # Use covering grammars (if enabled) if self.use_cg: # Compute sequence probabilities sequence_probs = torch.ones(len(all_inputs)).to(self.device) for ix, cur_toks_scores in enumerate(sequence_toks_scores): cur_generated_ids = generated_ids[:, ix + 1].tolist() cur_toks_probs = torch.nn.functional.softmax(cur_toks_scores, dim=-1) # Compute selected_toks_probs selected_toks_probs = [] for jx, _id in enumerate(cur_generated_ids): if _id != self._tokenizer.pad_token_id: selected_toks_probs.append(cur_toks_probs[jx, _id]) else: selected_toks_probs.append(1) selected_toks_probs = torch.tensor(selected_toks_probs).to( self.device) sequence_probs *= selected_toks_probs # For TN cases where the neural model is not confident, use CGs neural_confidence_threshold = self.neural_confidence_threshold for ix, (_dir, _input, _prob) in enumerate( zip(input_dirs, input_centers, sequence_probs)): if _dir == constants.INST_FORWARD and _prob < neural_confidence_threshold: if is_url(_input): _input = _input.replace(' ', '') # Remove spaces in URLs try: cg_outputs = self.cg_normalizer.normalize( text=_input, verbose=False, n_tagged=self.n_tagged) generated_texts[ix] = list(cg_outputs)[0] except: # if there is any exception, fall back to the input generated_texts[ix] = _input # Post processing generated_texts = self.postprocess_output_spans( input_centers, generated_texts, input_dirs) # Prepare final_texts final_texts, span_ctx = [], 0 for nb_span in nb_spans: cur_texts = [] for i in range(nb_span): cur_texts.append(generated_texts[span_ctx]) span_ctx += 1 final_texts.append(cur_texts) return final_texts def postprocess_output_spans(self, input_centers: List[str], generated_spans: List[str], input_dirs: List[str]): """ Post processing of the generated texts Args: input_centers: Input str (no special tokens or context) generated_spans: Generated spans input_dirs: task direction: constants.INST_BACKWARD or constants.INST_FORWARD Returns: Processing texts """ en_greek_writtens = list(constants.EN_GREEK_TO_SPOKEN.keys()) en_greek_spokens = list(constants.EN_GREEK_TO_SPOKEN.values()) for ix, (_input, _output) in enumerate(zip(input_centers, generated_spans)): if self.lang == constants.ENGLISH: # Handle URL if is_url(_input): _output = _output.replace('http', ' h t t p ') _output = _output.replace('/', ' slash ') _output = _output.replace('.', ' dot ') _output = _output.replace(':', ' colon ') _output = _output.replace('-', ' dash ') _output = _output.replace('_', ' underscore ') _output = _output.replace('%', ' percent ') _output = _output.replace('www', ' w w w ') _output = _output.replace('ftp', ' f t p ') generated_spans[ix] = ' '.join(wordninja.split(_output)) continue # Greek letters if _input in en_greek_writtens: if input_dirs[ix] == constants.INST_FORWARD: generated_spans[ix] = constants.EN_GREEK_TO_SPOKEN[ _input] if _input in en_greek_spokens: if input_dirs[ix] == constants.INST_FORWARD: generated_spans[ix] = _input if input_dirs[ix] == constants.INST_BACKWARD: generated_spans[ix] = constants.EN_SPOKEN_TO_GREEK[ _input] return generated_spans # Functions for processing data def setup_training_data(self, train_data_config: Optional[DictConfig]): if not train_data_config or not train_data_config.data_path: logging.info( f"Dataloader config or file_path for the train is missing, so no data loader for train is created!" ) self.train_dataset, self._train_dl = None, None return self.train_dataset, self._train_dl = self._setup_dataloader_from_config( cfg=train_data_config, data_split="train") def setup_validation_data(self, val_data_config: Optional[DictConfig]): if not val_data_config or not val_data_config.data_path: logging.info( f"Dataloader config or file_path for the validation is missing, so no data loader for validation is created!" ) self.validation_dataset, self._validation_dl = None, None return self.validation_dataset, self._validation_dl = self._setup_dataloader_from_config( cfg=val_data_config, data_split="val") def setup_multiple_validation_data(self, val_data_config: Union[DictConfig, Dict] = None): if val_data_config is None: val_data_config = self._cfg.validation_ds return super().setup_multiple_validation_data(val_data_config) def setup_multiple_test_data(self, test_data_config: Union[DictConfig, Dict] = None): if test_data_config is None: test_data_config = self._cfg.test_ds return super().setup_multiple_test_data(test_data_config) def setup_test_data(self, test_data_config: Optional[DictConfig]): if not test_data_config or test_data_config.data_path is None: logging.info( f"Dataloader config or file_path for the test is missing, so no data loader for test is created!" ) self.test_dataset, self._test_dl = None, None return self.test_dataset, self._test_dl = self._setup_dataloader_from_config( cfg=test_data_config, data_split="test") def _setup_dataloader_from_config(self, cfg: DictConfig, data_split: str): logging.info(f"Creating {data_split} dataset") shuffle = cfg["shuffle"] if cfg.get("use_tarred_dataset", False): logging.info('Tarred dataset') metadata_file = cfg["tar_metadata_file"] if metadata_file is None or not os.path.exists(metadata_file): raise FileNotFoundError( f"Trying to use tarred dataset but could not find {metadata_file}." ) with open(metadata_file, "r") as f: metadata = json.load(f) num_batches = metadata["num_batches"] tar_files = os.path.join(os.path.dirname(metadata_file), metadata["text_tar_filepaths"]) logging.info(f"Loading {tar_files}") dataset = TarredTextNormalizationDecoderDataset( text_tar_filepaths=tar_files, num_batches=num_batches, shuffle_n=cfg.get("tar_shuffle_n", 4 * cfg['batch_size']) if shuffle else 0, shard_strategy=cfg.get("shard_strategy", "scatter"), global_rank=self.global_rank, world_size=self.world_size, ) dl = torch.utils.data.DataLoader( dataset=dataset, batch_size=1, sampler=None, num_workers=cfg.get("num_workers", 2), pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), ) else: input_file = cfg.data_path if not os.path.exists(input_file): raise ValueError(f"{input_file} not found.") dataset = TextNormalizationDecoderDataset( input_file=input_file, tokenizer=self._tokenizer, tokenizer_name=self.transformer_name, mode=self.mode, max_len=self.max_sequence_len, decoder_data_augmentation=cfg.get('decoder_data_augmentation', False) if data_split == "train" else False, lang=self.lang, do_basic_tokenize=cfg.do_basic_tokenize, use_cache=cfg.get('use_cache', False), max_insts=cfg.get('max_insts', -1), do_tokenize=True, ) # create and save class names to class_ids mapping for validation # (each validation set might have different classes) if data_split in ['val', 'test']: if not hasattr(self, "_val_class_to_id"): self._val_class_to_id = [] self._val_id_to_class = [] self._val_class_to_id.append(dataset.label_ids_semiotic) self._val_id_to_class.append( {v: k for k, v in dataset.label_ids_semiotic.items()}) data_collator = DataCollatorForSeq2Seq( self._tokenizer, model=self.model, label_pad_token_id=constants.LABEL_PAD_TOKEN_ID, padding=True) dl = torch.utils.data.DataLoader( dataset=dataset, batch_size=cfg.batch_size, shuffle=shuffle, collate_fn=data_collator, num_workers=cfg.get("num_workers", 3), pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), ) return dataset, dl @classmethod def list_available_models(cls) -> Optional[PretrainedModelInfo]: """ This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. Returns: List of available pre-trained models. """ result = [] return result
class TestDate: inverse_normalizer_en = (InverseNormalizer( lang='en', cache_dir=CACHE_DIR, overwrite_cache=False) if PYNINI_AVAILABLE else None) @parameterized.expand( parse_test_case_file( 'en/data_inverse_text_normalization/test_cases_date.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_denorm(self, test_input, expected): pred = self.inverse_normalizer_en.inverse_normalize(test_input, verbose=False) assert pred == expected normalizer_en = (Normalizer(input_case='cased', lang='en', cache_dir=CACHE_DIR, overwrite_cache=False, post_process=True) if PYNINI_AVAILABLE else None) normalizer_with_audio_en = (NormalizerWithAudio(input_case='cased', lang='en', cache_dir=CACHE_DIR, overwrite_cache=False) if PYNINI_AVAILABLE and RUN_AUDIO_BASED_TESTS else None) @parameterized.expand( parse_test_case_file('en/data_text_normalization/test_cases_date.txt')) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm_uncased(self, test_input, expected): pred = self.normalizer_en.normalize(test_input, verbose=False) assert pred == expected if self.normalizer_with_audio_en: pred_non_deterministic = self.normalizer_with_audio_en.normalize( test_input, punct_post_process=False, n_tagged=100) assert expected in pred_non_deterministic, f"INPUT: {test_input}" normalizer_uppercased = (Normalizer(input_case='cased', lang='en', cache_dir=CACHE_DIR, overwrite_cache=False) if PYNINI_AVAILABLE else None) cases_uppercased = { "Aug. 8": "august eighth", "8 Aug.": "the eighth of august", "aug. 8": "august eighth" } @parameterized.expand(cases_uppercased.items()) @pytest.mark.skipif( not PYNINI_AVAILABLE, reason= "`pynini` not installed, please install via nemo_text_processing/setup.sh" ) @pytest.mark.run_only_on('CPU') @pytest.mark.unit def test_norm_cased(self, test_input, expected): pred = self.normalizer_uppercased.normalize(test_input, verbose=False) assert pred == expected if self.normalizer_with_audio_en: pred_non_deterministic = self.normalizer_with_audio_en.normalize( test_input, punct_post_process=False, n_tagged=30) assert expected in pred_non_deterministic
class DuplexDecoderModel(NLPModel): """ Transformer-based (duplex) decoder model for TN/ITN. """ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self._tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer) super().__init__(cfg=cfg, trainer=trainer) self.model = AutoModelForSeq2SeqLM.from_pretrained(cfg.transformer) self.transformer_name = cfg.transformer # Language self.lang = cfg.get('lang', None) # Covering Grammars self.cg_normalizer = None # Default # We only support integrating with English TN covering grammars at the moment self.use_cg = cfg.get('use_cg', False) and self.lang == constants.ENGLISH if self.use_cg: self.setup_cgs(cfg) # Setup covering grammars (if enabled) def setup_cgs(self, cfg: DictConfig): """ Setup covering grammars (if enabled). :param cfg: Configs of the decoder model. """ self.use_cg = True self.neural_confidence_threshold = cfg.get( 'neural_confidence_threshold', 0.99) self.n_tagged = cfg.get('n_tagged', 1) input_case = 'cased' # input_case is cased by default if hasattr(self._tokenizer, 'do_lower_case') and self._tokenizer.do_lower_case: input_case = 'lower_cased' if not PYNINI_AVAILABLE: raise Exception( "`pynini` is not installed ! \n" "Please run the `nemo_text_processing/setup.sh` script" "prior to usage of this toolkit.") self.cg_normalizer = NormalizerWithAudio(input_case=input_case, lang=self.lang) # Training def training_step(self, batch, batch_idx): """ Lightning calls this inside the training loop with the data from the training dataloader passed in as `batch`. """ # Apply Transformer outputs = self.model( input_ids=batch['input_ids'], decoder_input_ids=batch['decoder_input_ids'], attention_mask=batch['attention_mask'], labels=batch['labels'], ) train_loss = outputs.loss lr = self._optimizer.param_groups[0]['lr'] self.log('train_loss', train_loss) self.log('lr', lr, prog_bar=True) return {'loss': train_loss, 'lr': lr} # Validation and Testing def validation_step(self, batch, batch_idx): """ Lightning calls this inside the validation loop with the data from the validation dataloader passed in as `batch`. """ # Apply Transformer outputs = self.model( input_ids=batch['input_ids'], decoder_input_ids=batch['decoder_input_ids'], attention_mask=batch['attention_mask'], labels=batch['labels'], ) val_loss = outputs.loss return {'val_loss': val_loss} def validation_epoch_end(self, outputs): """ Called at the end of validation to aggregate outputs. :param outputs: list of individual outputs of each validation step. """ avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() self.log('val_loss', avg_loss) return { 'val_loss': avg_loss, } def test_step(self, batch, batch_idx): """ Lightning calls this inside the test loop with the data from the test dataloader passed in as `batch`. """ return self.validation_step(batch, batch_idx) def test_epoch_end(self, outputs): """ Called at the end of test to aggregate outputs. :param outputs: list of individual outputs of each test step. """ return self.validation_epoch_end(outputs) # Functions for inference @torch.no_grad() def _infer( self, sents: List[List[str]], nb_spans: List[int], span_starts: List[List[int]], span_ends: List[List[int]], inst_directions: List[str], ): """ Main function for Inference Args: sents: A list of inputs tokenized by a basic tokenizer. nb_spans: A list of ints where each int indicates the number of semiotic spans in each input. span_starts: A list of lists where each list contains the starting locations of semiotic spans in an input. span_ends: A list of lists where each list contains the ending locations of semiotic spans in an input. inst_directions: A list of str where each str indicates the direction of the corresponding instance (i.e., INST_BACKWARD for ITN or INST_FORWARD for TN). Returns: A list of lists where each list contains the decoded spans for the corresponding input. """ self.eval() if sum(nb_spans) == 0: return [[]] * len(sents) model, tokenizer = self.model, self._tokenizer try: model_max_len = model.config.n_positions except AttributeError: model_max_len = 512 ctx_size = constants.DECODE_CTX_SIZE extra_id_0 = constants.EXTRA_ID_0 extra_id_1 = constants.EXTRA_ID_1 # Build all_inputs input_centers, input_dirs, all_inputs = [], [], [] for ix, sent in enumerate(sents): cur_inputs = [] for jx in range(nb_spans[ix]): cur_start = span_starts[ix][jx] cur_end = span_ends[ix][jx] ctx_left = sent[max(0, cur_start - ctx_size):cur_start] ctx_right = sent[cur_end + 1:cur_end + 1 + ctx_size] span_words = sent[cur_start:cur_end + 1] span_words_str = ' '.join(span_words) if is_url(span_words_str): span_words_str = span_words_str.lower() input_centers.append(span_words_str) input_dirs.append(inst_directions[ix]) # Build cur_inputs if inst_directions[ix] == constants.INST_BACKWARD: cur_inputs = [constants.ITN_PREFIX] if inst_directions[ix] == constants.INST_FORWARD: cur_inputs = [constants.TN_PREFIX] cur_inputs += ctx_left cur_inputs += [extra_id_0 ] + span_words_str.split(' ') + [extra_id_1] cur_inputs += ctx_right all_inputs.append(' '.join(cur_inputs)) # Apply the decoding model batch = tokenizer(all_inputs, padding=True, return_tensors='pt') input_ids = batch['input_ids'].to(self.device) outputs = model.generate(input_ids, output_scores=True, return_dict_in_generate=True, max_length=model_max_len) generated_ids, sequence_toks_scores = outputs['sequences'], outputs[ 'scores'] generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) # Use covering grammars (if enabled) if self.use_cg: # Compute sequence probabilities sequence_probs = torch.ones(len(all_inputs)).to(self.device) for ix, cur_toks_scores in enumerate(sequence_toks_scores): cur_generated_ids = generated_ids[:, ix + 1].tolist() cur_toks_probs = torch.nn.functional.softmax(cur_toks_scores, dim=-1) # Compute selected_toks_probs selected_toks_probs = [] for jx, _id in enumerate(cur_generated_ids): if _id != self._tokenizer.pad_token_id: selected_toks_probs.append(cur_toks_probs[jx, _id]) else: selected_toks_probs.append(1) selected_toks_probs = torch.tensor(selected_toks_probs).to( self.device) sequence_probs *= selected_toks_probs # For TN cases where the neural model is not confident, use CGs neural_confidence_threshold = self.neural_confidence_threshold for ix, (_dir, _input, _prob) in enumerate( zip(input_dirs, input_centers, sequence_probs)): if _dir == constants.INST_FORWARD and _prob < neural_confidence_threshold: if is_url(_input): _input = _input.replace(' ', '') # Remove spaces in URLs try: cg_outputs = self.cg_normalizer.normalize( text=_input, verbose=False, n_tagged=self.n_tagged) generated_texts[ix] = list(cg_outputs)[0] except: # if there is any exception, fall back to the input generated_texts[ix] = _input # Post processing generated_texts = self.postprocess_output_spans( input_centers, generated_texts, input_dirs) # Prepare final_texts final_texts, span_ctx = [], 0 for nb_span in nb_spans: cur_texts = [] for i in range(nb_span): cur_texts.append(generated_texts[span_ctx]) span_ctx += 1 final_texts.append(cur_texts) return final_texts def postprocess_output_spans(self, input_centers, output_spans, input_dirs): en_greek_writtens = list(constants.EN_GREEK_TO_SPOKEN.keys()) en_greek_spokens = list(constants.EN_GREEK_TO_SPOKEN.values()) for ix, (_input, _output) in enumerate(zip(input_centers, output_spans)): if self.lang == constants.ENGLISH: # Handle URL if is_url(_input): _output = _output.replace('http', ' h t t p ') _output = _output.replace('/', ' slash ') _output = _output.replace('.', ' dot ') _output = _output.replace(':', ' colon ') _output = _output.replace('-', ' dash ') _output = _output.replace('_', ' underscore ') _output = _output.replace('%', ' percent ') _output = _output.replace('www', ' w w w ') _output = _output.replace('ftp', ' f t p ') output_spans[ix] = ' '.join(wordninja.split(_output)) continue # Greek letters if _input in en_greek_writtens: if input_dirs[ix] == constants.INST_FORWARD: output_spans[ix] = constants.EN_GREEK_TO_SPOKEN[_input] if _input in en_greek_spokens: if input_dirs[ix] == constants.INST_FORWARD: output_spans[ix] = _input if input_dirs[ix] == constants.INST_BACKWARD: output_spans[ix] = constants.EN_SPOKEN_TO_GREEK[_input] return output_spans # Functions for processing data def setup_training_data(self, train_data_config: Optional[DictConfig]): if not train_data_config or not train_data_config.data_path: logging.info( f"Dataloader config or file_path for the train is missing, so no data loader for train is created!" ) self.train_dataset, self._train_dl = None, None return self.train_dataset, self._train_dl = self._setup_dataloader_from_config( cfg=train_data_config, mode="train") def setup_validation_data(self, val_data_config: Optional[DictConfig]): if not val_data_config or not val_data_config.data_path: logging.info( f"Dataloader config or file_path for the validation is missing, so no data loader for validation is created!" ) self.validation_dataset, self._validation_dl = None, None return self.validation_dataset, self._validation_dl = self._setup_dataloader_from_config( cfg=val_data_config, mode="val") def setup_test_data(self, test_data_config: Optional[DictConfig]): if not test_data_config or test_data_config.data_path is None: logging.info( f"Dataloader config or file_path for the test is missing, so no data loader for test is created!" ) self.test_dataset, self._test_dl = None, None return self.test_dataset, self._test_dl = self._setup_dataloader_from_config( cfg=test_data_config, mode="test") def _setup_dataloader_from_config(self, cfg: DictConfig, mode: str): tokenizer, model = self._tokenizer, self.model start_time = perf_counter() logging.info(f'Creating {mode} dataset') input_file = cfg.data_path dataset = TextNormalizationDecoderDataset( input_file, tokenizer, self.transformer_name, cfg.mode, cfg.get('max_decoder_len', tokenizer.model_max_length), cfg.get('decoder_data_augmentation', False), cfg.lang, cfg.do_basic_tokenize, cfg.get('use_cache', False), cfg.get('max_insts', -1), ) data_collator = DataCollatorForSeq2Seq( tokenizer, model=model, label_pad_token_id=constants.LABEL_PAD_TOKEN_ID, ) dl = torch.utils.data.DataLoader(dataset=dataset, batch_size=cfg.batch_size, shuffle=cfg.shuffle, collate_fn=data_collator) running_time = perf_counter() - start_time logging.info(f'Took {running_time} seconds') return dataset, dl @classmethod def list_available_models(cls) -> Optional[PretrainedModelInfo]: """ This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. Returns: List of available pre-trained models. """ result = [] return result