def main(cfg: DictConfig) -> None:
    """Run duplex TN/ITN inference: batch evaluation on a test set, or an interactive loop.

    Args:
        cfg: Hydra/OmegaConf config with ``inference`` and ``data.test_ds`` sections.
    """
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')
    # The trainers are not needed for inference; only the restored models are used,
    # so discard them instead of binding unused locals.
    _, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, False)
    _, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, False)
    tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model)

    if not cfg.inference.interactive:
        # Batch mode: evaluate on the configured test dataset.
        test_dataset = TextNormalizationTestDataset(cfg.data.test_ds.data_path, cfg.data.test_ds.mode)
        results = tn_model.evaluate(test_dataset, cfg.data.test_ds.batch_size, cfg.inference.errors_log_fp)
        print(f'\nTest results: {results}')
    else:
        # Interactive mode: run each typed input through both directions (ITN and TN)
        # until the user answers 'n' to the continue prompt.
        while True:
            test_input = input('Input a test input:')
            test_input = ' '.join(word_tokenize(test_input))
            outputs = tn_model._infer(
                [test_input, test_input], [constants.INST_BACKWARD, constants.INST_FORWARD]
            )[-1]
            print(f'Prediction (ITN): {outputs[0]}')
            print(f'Prediction (TN): {outputs[1]}')
            should_continue = input('\nContinue (y/n): ').strip().lower()
            if should_continue.startswith('n'):
                break
def main(cfg: DictConfig) -> None:
    """Evaluate the pretrained tagger and decoder individually, then the combined duplex model.

    Each sub-model is evaluated only when its checkpoint is configured; the joint
    (tagger + decoder) evaluation runs only when both checkpoints are available.
    """
    language = cfg.lang

    # Pre-declare so the joint-evaluation branch can reference them safely.
    tagger, decoder = None, None

    if cfg.tagger_pretrained_model:
        tagger_pl_trainer, tagger = instantiate_model_and_trainer(cfg, TAGGER_MODEL, False)
        tagger.max_sequence_len = 512
        tagger.setup_test_data(cfg.data.test_ds)
        logging.info('Evaluating the tagger...')
        tagger_pl_trainer.test(model=tagger, verbose=False)
    else:
        logging.info('Tagger checkpoint is not provided, skipping tagger evaluation')

    if cfg.decoder_pretrained_model:
        decoder_pl_trainer, decoder = instantiate_model_and_trainer(cfg, DECODER_MODEL, False)
        decoder.max_sequence_len = 512
        decoder.setup_multiple_test_data(cfg.data.test_ds)
        logging.info('Evaluating the decoder...')
        decoder_pl_trainer.test(decoder)
    else:
        logging.info('Decoder checkpoint is not provided, skipping decoder evaluation')

    if cfg.tagger_pretrained_model and cfg.decoder_pretrained_model:
        logging.info('Running evaluation of the duplex model (tagger + decoder) on the test set.')
        duplex = DuplexTextNormalizationModel(tagger, decoder, language)
        dataset = TextNormalizationTestDataset(cfg.data.test_ds.data_path, cfg.mode, language)
        metrics = duplex.evaluate(dataset, cfg.data.test_ds.batch_size, cfg.data.test_ds.errors_log_fp)
        print(f'\nTest results: {metrics}')
def _train_component(cfg: DictConfig, model_key, component: str):
    """Train one sub-model ('tagger' or 'decoder') if enabled by its config; return it, else None.

    ``component`` selects the ``<component>_model`` and ``<component>_exp_manager``
    sections of ``cfg``; ``model_key`` is the constant passed to
    ``instantiate_model_and_trainer`` (TAGGER_MODEL / DECODER_MODEL).
    """
    model_cfg = cfg.get(f'{component}_model')
    if not model_cfg.do_training:
        return None
    logging.info(
        "================================================================================================"
    )
    logging.info(f'Starting training {component}...')
    trainer, model = instantiate_model_and_trainer(cfg, model_key, True)
    component_exp_manager = cfg.get(f'{component}_exp_manager', None)
    exp_manager(trainer, component_exp_manager)
    trainer.fit(model)
    # Save a .nemo checkpoint only when checkpointing was enabled and a target path is set.
    if (
        component_exp_manager
        and component_exp_manager.get('create_checkpoint_callback', False)
        and model_cfg.nemo_path
    ):
        model.to(trainer.accelerator.root_device)
        model.save_to(model_cfg.nemo_path)
    logging.info('Training finished!')
    return model


def main(cfg: DictConfig) -> None:
    """Train the tagger and/or decoder as configured, then evaluate the joint duplex model.

    The tagger and decoder training paths were near-duplicates; both now go through
    ``_train_component``. Joint evaluation requires a configured test set and that
    BOTH sub-models were trained in this run (same condition as before, expressed
    via the returned models).
    """
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')

    tagger_model = _train_component(cfg, TAGGER_MODEL, 'tagger')
    decoder_model = _train_component(cfg, DECODER_MODEL, 'decoder')

    # Evaluation after training
    if (
        hasattr(cfg.data, 'test_ds')
        and cfg.data.test_ds.data_path is not None
        and tagger_model is not None
        and decoder_model is not None
    ):
        tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model, cfg.lang)
        test_dataset = TextNormalizationTestDataset(cfg.data.test_ds.data_path, cfg.mode, cfg.lang)
        results = tn_model.evaluate(test_dataset, cfg.data.test_ds.batch_size, cfg.data.test_ds.errors_log_fp)
        print(f'\nTest results: {results}')
def main(cfg: DictConfig) -> None:
    """Run duplex TN/ITN inference either over a text file or interactively.

    Requires both pretrained checkpoints (tagger and decoder) in the config.
    In file mode, predictions for each requested direction are written next to
    the input file as ``<name>_<mode><ext>``.

    Raises:
        ValueError: if either pretrained model is missing, or the input file does not exist.
    """
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')
    lang = cfg.lang

    do_basic_tokenization = True

    if cfg.decoder_pretrained_model is None or cfg.tagger_pretrained_model is None:
        raise ValueError(
            "Both pre-trained models (DuplexTaggerModel and DuplexDecoderModel) should be provided."
        )
    # The trainers are unused for inference — only the restored models matter.
    _, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, False)
    _, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, False)
    tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model, lang)

    if cfg.inference.get("from_file", False):
        text_file = cfg.inference.from_file
        logging.info(f'Running inference on {text_file}...')
        if not os.path.exists(text_file):
            raise ValueError(f'{text_file} not found.')

        with open(text_file, 'r') as f:
            lines = f.readlines()

        def _get_predictions(lines: List[str], mode: str, batch_size: int, text_file: str):
            """Runs inference on unlabeled lines and saves predictions to '<name>_<mode><ext>'."""
            assert mode in ['tn', 'itn']
            file_name, extension = os.path.splitext(text_file)
            batch, all_preds = [], []
            for i, line in enumerate(lines):
                batch.append(line.strip())
                # Flush a full batch, or the final (possibly partial) batch on the last line.
                if len(batch) == batch_size or i == len(lines) - 1:
                    outputs = tn_model._infer(
                        batch,
                        [constants.DIRECTIONS_TO_MODE[mode]] * len(batch),
                        do_basic_tokenization=do_basic_tokenization,
                    )
                    all_preds.extend(outputs[-1])
                    batch = []
            assert len(all_preds) == len(lines)
            out_file = f'{file_name}_{mode}{extension}'
            with open(out_file, 'w') as f_out:
                f_out.write("\n".join(all_preds))
            logging.info(f'Predictions for {mode} save to {out_file}.')

        batch_size = cfg.inference.get("batch_size", 8)
        if cfg.mode in ['tn', 'joint']:
            # TN mode
            _get_predictions(lines, 'tn', batch_size, text_file)
        if cfg.mode in ['itn', 'joint']:
            # ITN mode
            _get_predictions(lines, 'itn', batch_size, text_file)
    else:
        print('Entering interactive mode.')
        while True:
            print('Type "STOP" to exit.')
            test_input = input('Input a test input:')
            if test_input == "STOP":
                break
            directions = []
            inputs = []
            # When both directions are requested, ITN is placed first, so the ITN
            # prediction is outputs[0] and the TN prediction is outputs[-1].
            if cfg.mode in ['itn', 'joint']:
                directions.append(constants.DIRECTIONS_TO_MODE[constants.ITN_MODE])
                inputs.append(test_input)
            if cfg.mode in ['tn', 'joint']:
                directions.append(constants.DIRECTIONS_TO_MODE[constants.TN_MODE])
                inputs.append(test_input)
            outputs = tn_model._infer(
                inputs,
                directions,
                do_basic_tokenization=do_basic_tokenization,
            )[-1]
            if cfg.mode in ['joint', 'itn']:
                print(f'Prediction (ITN): {outputs[0]}')
            if cfg.mode in ['joint', 'tn']:
                print(f'Prediction (TN): {outputs[-1]}')