def main(cfg: DictConfig) -> None:
    lang = cfg.lang

    if cfg.tagger_pretrained_model:
        tagger_trainer, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, False)
        tagger_model.max_sequence_len = 512
        tagger_model.setup_test_data(cfg.data.test_ds)
        logging.info('Evaluating the tagger...')
        tagger_trainer.test(model=tagger_model, verbose=False)
    else:
        logging.info('Tagger checkpoint is not provided, skipping tagger evaluation')

    if cfg.decoder_pretrained_model:
        decoder_trainer, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, False)
        decoder_model.max_sequence_len = 512
        decoder_model.setup_multiple_test_data(cfg.data.test_ds)
        logging.info('Evaluating the decoder...')
        decoder_trainer.test(decoder_model)
    else:
        logging.info('Decoder checkpoint is not provided, skipping decoder evaluation')

    if cfg.tagger_pretrained_model and cfg.decoder_pretrained_model:
        logging.info('Running evaluation of the duplex model (tagger + decoder) on the test set.')
        tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model, lang)
        test_dataset = TextNormalizationTestDataset(cfg.data.test_ds.data_path, cfg.mode, lang)
        results = tn_model.evaluate(test_dataset, cfg.data.test_ds.batch_size, cfg.data.test_ds.errors_log_fp)
        print(f'\nTest results: {results}')
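# A hypothetical invocation sketch for the evaluation entry point above, assuming
# the script is launched with Hydra-style `key=value` overrides. The script name
# and file paths are placeholders, not taken from the source:
#
#   python duplex_text_normalization_test.py \
#       lang=en \
#       mode=joint \
#       tagger_pretrained_model=tagger.nemo \
#       decoder_pretrained_model=decoder.nemo \
#       data.test_ds.data_path=test.tsv \
#       data.test_ds.batch_size=32 \
#       data.test_ds.errors_log_fp=errors.log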
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')
    tagger_trainer, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, False)
    decoder_trainer, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, False)
    tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model)

    if not cfg.inference.interactive:
        # Setup test_dataset
        test_dataset = TextNormalizationTestDataset(cfg.data.test_ds.data_path, cfg.data.test_ds.mode)
        results = tn_model.evaluate(test_dataset, cfg.data.test_ds.batch_size, cfg.inference.errors_log_fp)
        print(f'\nTest results: {results}')
    else:
        while True:
            test_input = input('Input a test input:')
            test_input = ' '.join(word_tokenize(test_input))
            outputs = tn_model._infer(
                [test_input, test_input], [constants.INST_BACKWARD, constants.INST_FORWARD]
            )[-1]
            print(f'Prediction (ITN): {outputs[0]}')
            print(f'Prediction (TN): {outputs[1]}')
            should_continue = input('\nContinue (y/n): ').strip().lower()
            if should_continue.startswith('n'):
                break
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')

    # Train the tagger
    if cfg.tagger_model.do_training:
        logging.info(
            "================================================================================================"
        )
        logging.info('Starting tagger training...')
        tagger_trainer, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, True)
        exp_manager(tagger_trainer, cfg.get('tagger_exp_manager', None))
        tagger_trainer.fit(tagger_model)
        if cfg.tagger_model.nemo_path:
            tagger_model.to(tagger_trainer.accelerator.root_device)
            tagger_model.save_to(cfg.tagger_model.nemo_path)
        logging.info('Training finished!')

    # Train the decoder
    if cfg.decoder_model.do_training:
        logging.info(
            "================================================================================================"
        )
        logging.info('Starting decoder training...')
        decoder_trainer, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, True)
        exp_manager(decoder_trainer, cfg.get('decoder_exp_manager', None))
        decoder_trainer.fit(decoder_model)
        if cfg.decoder_model.nemo_path:
            decoder_model.to(decoder_trainer.accelerator.root_device)
            decoder_model.save_to(cfg.decoder_model.nemo_path)
        logging.info('Training finished!')
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')
    lang, batch_size = cfg.lang, cfg.data.test_ds.batch_size

    tagger_trainer, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, False)
    decoder_trainer, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, False)

    # Evaluate the tagger
    print('Evaluating the tagger')
    tagger_model.setup_test_data(cfg.data.test_ds)
    tagger_trainer.test(model=tagger_model, verbose=False)

    # Evaluate the decoder
    print('Evaluating the decoder')
    transformer_model, tokenizer = decoder_model.model, decoder_model._tokenizer
    try:
        model_max_len = transformer_model.config.n_positions
    except AttributeError:
        # Fall back to a default when the model config does not expose n_positions
        model_max_len = 512

    # Load the test dataset
    decoder_model.setup_test_data(cfg.data.test_ds)
    test_dataset, test_dl = decoder_model.test_dataset, decoder_model._test_dl

    # Inference
    itn_class2stats, tn_class2stats = {}, {}
    for ix, examples in tqdm(enumerate(test_dl)):
        # Extract info for the current batch
        start_idx = ix * batch_size
        end_idx = min((ix + 1) * batch_size, len(test_dataset))
        batch_insts = test_dataset.insts[start_idx:end_idx]
        batch_input_centers = [inst.input_center_str for inst in batch_insts]
        batch_targets = [inst.output_str for inst in batch_insts]
        batch_dirs = [inst.direction for inst in batch_insts]
        batch_classes = [inst.semiotic_class for inst in batch_insts]

        # Inference
        input_ids = examples['input_ids'].to(decoder_model.device)
        generated_ids = transformer_model.generate(input_ids, max_length=model_max_len)
        batch_preds = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        batch_preds = decoder_model.postprocess_output_spans(batch_input_centers, batch_preds, batch_dirs)

        # Update itn_class2stats and tn_class2stats
        for direction, _class, pred, target in zip(batch_dirs, batch_classes, batch_preds, batch_targets):
            correct = TextNormalizationTestDataset.is_same(pred, target, direction, lang)
            stats = itn_class2stats if direction == constants.INST_BACKWARD else tn_class2stats
            if _class not in stats:
                stats[_class] = []
            stats[_class].append(int(correct))

    # Print out stats
    print('ITN (Backward Direction)')
    print_class_based_stats(itn_class2stats)
    print('TN (Forward Direction)')
    print_class_based_stats(tn_class2stats)
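# `print_class_based_stats` is called above but not defined in this snippet. Below
# is a minimal illustrative stand-in, assuming each stats dict maps a semiotic
# class to a list of 0/1 correctness flags; this is not the actual helper imported
# by the script:
def print_class_based_stats(class2stats: dict) -> None:
    # Report per-class accuracy as correct/total
    for _class, flags in sorted(class2stats.items()):
        total = len(flags)
        correct = sum(flags)
        accuracy = correct / total if total else 0.0
        print(f'{_class}: accuracy={accuracy:.4f} ({correct}/{total})')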
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')

    # Train the tagger
    if cfg.tagger_model.do_training:
        logging.info(
            "================================================================================================"
        )
        logging.info('Starting tagger training...')
        tagger_trainer, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, True)
        tagger_exp_manager = cfg.get('tagger_exp_manager', None)
        exp_manager(tagger_trainer, tagger_exp_manager)
        tagger_trainer.fit(tagger_model)
        if (
            tagger_exp_manager
            and tagger_exp_manager.get('create_checkpoint_callback', False)
            and cfg.tagger_model.nemo_path
        ):
            tagger_model.to(tagger_trainer.accelerator.root_device)
            tagger_model.save_to(cfg.tagger_model.nemo_path)
        logging.info('Training finished!')

    # Train the decoder
    if cfg.decoder_model.do_training:
        logging.info(
            "================================================================================================"
        )
        logging.info('Starting decoder training...')
        decoder_trainer, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, True)
        decoder_exp_manager = cfg.get('decoder_exp_manager', None)
        exp_manager(decoder_trainer, decoder_exp_manager)
        decoder_trainer.fit(decoder_model)
        if (
            decoder_exp_manager
            and decoder_exp_manager.get('create_checkpoint_callback', False)
            and cfg.decoder_model.nemo_path
        ):
            decoder_model.to(decoder_trainer.accelerator.root_device)
            decoder_model.save_to(cfg.decoder_model.nemo_path)
        logging.info('Training finished!')

    # Evaluation after training
    if (
        hasattr(cfg.data, 'test_ds')
        and cfg.data.test_ds.data_path is not None
        and cfg.tagger_model.do_training
        and cfg.decoder_model.do_training
    ):
        tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model, cfg.lang)
        test_dataset = TextNormalizationTestDataset(cfg.data.test_ds.data_path, cfg.mode, cfg.lang)
        results = tn_model.evaluate(test_dataset, cfg.data.test_ds.batch_size, cfg.data.test_ds.errors_log_fp)
        print(f'\nTest results: {results}')
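# A hypothetical launch sketch for the combined train-then-evaluate entry point
# above, again assuming Hydra-style overrides; the script name and paths are
# placeholders:
#
#   python duplex_text_normalization_train.py \
#       lang=en \
#       mode=joint \
#       tagger_model.do_training=true \
#       decoder_model.do_training=true \
#       tagger_model.nemo_path=tagger.nemo \
#       decoder_model.nemo_path=decoder.nemo \
#       data.test_ds.data_path=test.tsv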
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')

    # Train the model
    if cfg.model.do_training:
        logging.info(
            "================================================================================================"
        )
        logging.info('Start training...')
        trainer, model = instantiate_model_and_trainer(cfg, ITN_MODEL, True)
        thutmose_tagger_exp_manager = cfg.get('exp_manager', None)
        exp_manager(trainer, thutmose_tagger_exp_manager)
        trainer.fit(model)
        logging.info('Training finished!')
def main(cfg: DictConfig) -> None:
    logging.debug(f'Config Params: {OmegaConf.to_yaml(cfg)}')

    if cfg.pretrained_model is None:
        raise ValueError("A pre-trained model should be provided.")
    _, model = instantiate_model_and_trainer(cfg, ITN_MODEL, False)

    text_file = cfg.inference.from_file
    logging.info(f"Running inference on {text_file}...")
    if not os.path.exists(text_file):
        raise ValueError(f"{text_file} not found.")

    with open(text_file, "r", encoding="utf-8") as f:
        lines = f.readlines()

    batch_size = cfg.inference.get("batch_size", 8)
    batch, all_preds = [], []
    for i, line in enumerate(lines):
        # Apply the same input transformation as in corpus preparation
        s = spoken_preprocessing(line)
        batch.append(s.strip())
        if len(batch) == batch_size or i == len(lines) - 1:
            outputs = model._infer(batch)
            all_preds.extend(outputs)
            batch = []
    if len(all_preds) != len(lines):
        raise ValueError(
            f"number of input lines and predictions is different: predictions={len(all_preds)}; lines={len(lines)}"
        )

    out_file = cfg.inference.out_file
    with open(out_file, "w", encoding="utf-8") as f_out:
        f_out.write("\n".join(all_preds))
    logging.info(f"Predictions saved to {out_file}.")
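# A hypothetical invocation sketch for the file-based inference entry point above
# (Hydra-style overrides; the script name and paths are placeholders):
#
#   python normalization_as_tagging_infer.py \
#       pretrained_model=itn_tagger.nemo \
#       inference.from_file=input.txt \
#       inference.out_file=output.txt \
#       inference.batch_size=16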
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')
    lang = cfg.lang
    do_basic_tokenization = True

    if cfg.decoder_pretrained_model is None or cfg.tagger_pretrained_model is None:
        raise ValueError("Both pre-trained models (DuplexTaggerModel and DuplexDecoderModel) should be provided.")
    tagger_trainer, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, False)
    decoder_trainer, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, False)
    tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model, lang)

    if cfg.inference.get("from_file", False):
        text_file = cfg.inference.from_file
        logging.info(f'Running inference on {text_file}...')
        if not os.path.exists(text_file):
            raise ValueError(f'{text_file} not found.')

        with open(text_file, 'r') as f:
            lines = f.readlines()

        def _get_predictions(lines: List[str], mode: str, batch_size: int, text_file: str):
            """Runs inference on unlabeled data and saves the predictions to a file."""
            assert mode in ['tn', 'itn']
            file_name, extension = os.path.splitext(text_file)
            batch, all_preds = [], []
            for i, line in enumerate(lines):
                batch.append(line.strip())
                if len(batch) == batch_size or i == len(lines) - 1:
                    outputs = tn_model._infer(
                        batch,
                        [constants.DIRECTIONS_TO_MODE[mode]] * len(batch),
                        do_basic_tokenization=do_basic_tokenization,
                    )
                    all_preds.extend(outputs[-1])
                    batch = []
            assert len(all_preds) == len(lines)
            out_file = f'{file_name}_{mode}{extension}'
            with open(out_file, 'w') as f_out:
                f_out.write("\n".join(all_preds))
            logging.info(f'Predictions for {mode} saved to {out_file}.')

        batch_size = cfg.inference.get("batch_size", 8)
        if cfg.mode in ['tn', 'joint']:
            # TN mode
            _get_predictions(lines, 'tn', batch_size, text_file)
        if cfg.mode in ['itn', 'joint']:
            # ITN mode
            _get_predictions(lines, 'itn', batch_size, text_file)
    else:
        print('Entering interactive mode.')
        while True:
            print('Type "STOP" to exit.')
            test_input = input('Input a test input:')
            if test_input == "STOP":
                break
            directions = []
            inputs = []
            if cfg.mode in ['itn', 'joint']:
                directions.append(constants.DIRECTIONS_TO_MODE[constants.ITN_MODE])
                inputs.append(test_input)
            if cfg.mode in ['tn', 'joint']:
                directions.append(constants.DIRECTIONS_TO_MODE[constants.TN_MODE])
                inputs.append(test_input)
            outputs = tn_model._infer(inputs, directions, do_basic_tokenization=do_basic_tokenization)[-1]
            if cfg.mode in ['joint', 'itn']:
                print(f'Prediction (ITN): {outputs[0]}')
            if cfg.mode in ['joint', 'tn']:
                print(f'Prediction (TN): {outputs[-1]}')
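# A hypothetical invocation sketch for the duplex inference entry point above
# (Hydra-style overrides; the script name and paths are placeholders). Omitting
# inference.from_file drops into the interactive prompt instead:
#
#   python duplex_text_normalization_infer.py \
#       lang=en \
#       mode=joint \
#       tagger_pretrained_model=tagger.nemo \
#       decoder_pretrained_model=decoder.nemo \
#       inference.from_file=input.txt \
#       inference.batch_size=16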