def main(cfg: DictConfig) -> None:
    """Run the duplex TN model either on a test dataset or as an interactive prompt loop.

    Args:
        cfg: Hydra/OmegaConf config with model, data, and inference sections.
    """
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')
    # Load pretrained tagger and decoder (False => not training mode), then pair them.
    tagger_trainer, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, False)
    decoder_trainer, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, False)
    tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model)

    if cfg.inference.interactive:
        # Interactive REPL: run both directions (ITN + TN) on each typed input.
        while True:
            raw_input_text = input('Input a test input:')
            tokenized = ' '.join(word_tokenize(raw_input_text))
            predictions = tn_model._infer(
                [tokenized, tokenized],
                [constants.INST_BACKWARD, constants.INST_FORWARD],
            )[-1]
            print(f'Prediction (ITN): {predictions[0]}')
            print(f'Prediction (TN): {predictions[1]}')

            answer = input('\nContinue (y/n): ').strip().lower()
            if answer.startswith('n'):
                break
    else:
        # Batch evaluation on the configured test split.
        test_dataset = TextNormalizationTestDataset(
            cfg.data.test_ds.data_path, cfg.data.test_ds.mode
        )
        results = tn_model.evaluate(
            test_dataset, cfg.data.test_ds.batch_size, cfg.inference.errors_log_fp
        )
        print(f'\nTest results: {results}')
Example #2
0
def main(cfg: DictConfig) -> None:
    """Evaluate tagger and decoder checkpoints individually, then jointly when both exist.

    Args:
        cfg: Hydra/OmegaConf config with optional pretrained-model paths and test data.
    """
    lang = cfg.lang
    # Truthiness of the checkpoint paths decides which evaluations run.
    has_tagger = bool(cfg.tagger_pretrained_model)
    has_decoder = bool(cfg.decoder_pretrained_model)

    if has_tagger:
        tagger_trainer, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, False)
        tagger_model.max_sequence_len = 512
        tagger_model.setup_test_data(cfg.data.test_ds)
        logging.info('Evaluating the tagger...')
        tagger_trainer.test(model=tagger_model, verbose=False)
    else:
        logging.info('Tagger checkpoint is not provided, skipping tagger evaluation')

    if has_decoder:
        decoder_trainer, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, False)
        decoder_model.max_sequence_len = 512
        decoder_model.setup_multiple_test_data(cfg.data.test_ds)
        logging.info('Evaluating the decoder...')
        decoder_trainer.test(decoder_model)
    else:
        logging.info('Decoder checkpoint is not provided, skipping decoder evaluation')

    # Joint (duplex) evaluation needs both components loaded above.
    if has_tagger and has_decoder:
        logging.info('Running evaluation of the duplex model (tagger + decoder) on the test set.')
        duplex_model = DuplexTextNormalizationModel(tagger_model, decoder_model, lang)
        dataset = TextNormalizationTestDataset(cfg.data.test_ds.data_path, cfg.mode, lang)
        results = duplex_model.evaluate(
            dataset, cfg.data.test_ds.batch_size, cfg.data.test_ds.errors_log_fp
        )
        print(f'\nTest results: {results}')
Example #3
0
def main(cfg: DictConfig) -> None:
    """Train the duplex TN tagger and/or decoder, then evaluate the joint model when both were trained.

    Args:
        cfg: Hydra/OmegaConf config with tagger_model/decoder_model sections and data paths.
    """
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')

    # Train the tagger
    if cfg.tagger_model.do_training:
        logging.info(
            "================================================================================================"
        )
        logging.info('Starting training tagger...')
        tagger_trainer, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, True)
        tagger_exp_manager = cfg.get('tagger_exp_manager', None)
        exp_manager(tagger_trainer, tagger_exp_manager)
        tagger_trainer.fit(tagger_model)
        # Persist a .nemo file only when checkpointing was enabled and a target path exists.
        should_save_tagger = (
            tagger_exp_manager
            and tagger_exp_manager.get('create_checkpoint_callback', False)
            and cfg.tagger_model.nemo_path
        )
        if should_save_tagger:
            tagger_model.to(tagger_trainer.accelerator.root_device)
            tagger_model.save_to(cfg.tagger_model.nemo_path)
        logging.info('Training finished!')

    # Train the decoder
    if cfg.decoder_model.do_training:
        logging.info(
            "================================================================================================"
        )
        logging.info('Starting training decoder...')
        decoder_trainer, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, True)
        decoder_exp_manager = cfg.get('decoder_exp_manager', None)
        exp_manager(decoder_trainer, decoder_exp_manager)
        decoder_trainer.fit(decoder_model)
        # Same save policy as the tagger above.
        should_save_decoder = (
            decoder_exp_manager
            and decoder_exp_manager.get('create_checkpoint_callback', False)
            and cfg.decoder_model.nemo_path
        )
        if should_save_decoder:
            decoder_model.to(decoder_trainer.accelerator.root_device)
            decoder_model.save_to(cfg.decoder_model.nemo_path)
        logging.info('Training finished!')

    # Evaluation after training: only valid when both components were trained in this run
    # and a test split is configured.
    should_evaluate = (
        hasattr(cfg.data, 'test_ds')
        and cfg.data.test_ds.data_path is not None
        and cfg.tagger_model.do_training
        and cfg.decoder_model.do_training
    )
    if should_evaluate:
        tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model, cfg.lang)
        test_dataset = TextNormalizationTestDataset(cfg.data.test_ds.data_path, cfg.mode, cfg.lang)
        results = tn_model.evaluate(
            test_dataset, cfg.data.test_ds.batch_size, cfg.data.test_ds.errors_log_fp
        )
        print(f'\nTest results: {results}')
Example #4
0
def main(cfg: DictConfig) -> None:
    """Run duplex TN inference, either over a text file or via an interactive prompt.

    Args:
        cfg: Hydra/OmegaConf config; must provide both pretrained model paths.

    Raises:
        ValueError: if either pretrained model path is missing, or the input file
            does not exist in from_file mode.
    """
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')
    lang = cfg.lang
    do_basic_tokenization = True
    # Both components are mandatory for inference.
    if cfg.decoder_pretrained_model is None or cfg.tagger_pretrained_model is None:
        raise ValueError(
            "Both pre-trained models (DuplexTaggerModel and DuplexDecoderModel) should be provided."
        )
    tagger_trainer, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, False)
    decoder_trainer, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, False)
    tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model, lang)

    if cfg.inference.get("from_file", False):
        text_file = cfg.inference.from_file
        logging.info(f'Running inference on {text_file}...')
        if not os.path.exists(text_file):
            raise ValueError(f'{text_file} not found.')

        with open(text_file, 'r') as f:
            lines = f.readlines()

        def _get_predictions(lines: List[str], mode: str, batch_size: int, text_file: str):
            """ Runs inference on a batch data without labels and saved predictions to a file. """
            assert mode in ['tn', 'itn']
            file_name, extension = os.path.splitext(text_file)
            stripped = [line.strip() for line in lines]
            all_preds = []
            # Process the file in fixed-size chunks; the last chunk may be shorter.
            for start in range(0, len(stripped), batch_size):
                chunk = stripped[start:start + batch_size]
                outputs = tn_model._infer(
                    chunk,
                    [constants.DIRECTIONS_TO_MODE[mode]] * len(chunk),
                    do_basic_tokenization=do_basic_tokenization,
                )
                all_preds.extend(outputs[-1])
            assert len(all_preds) == len(lines)
            out_file = f'{file_name}_{mode}{extension}'
            with open(f'{out_file}', 'w') as f_out:
                f_out.write("\n".join(all_preds))
            logging.info(f'Predictions for {mode} save to {out_file}.')

        batch_size = cfg.inference.get("batch_size", 8)
        if cfg.mode in ['tn', 'joint']:
            # TN mode
            _get_predictions(lines, 'tn', batch_size, text_file)
        if cfg.mode in ['itn', 'joint']:
            # ITN mode
            _get_predictions(lines, 'itn', batch_size, text_file)

    else:
        print('Entering interactive mode.')
        while True:
            print('Type "STOP" to exit.')
            test_input = input('Input a test input:')
            if test_input == "STOP":
                break
            # Build parallel direction/input lists; ITN (if requested) always comes first,
            # so outputs[0] is ITN and outputs[-1] is TN.
            directions, inputs = [], []
            if cfg.mode in ['itn', 'joint']:
                directions.append(constants.DIRECTIONS_TO_MODE[constants.ITN_MODE])
                inputs.append(test_input)
            if cfg.mode in ['tn', 'joint']:
                directions.append(constants.DIRECTIONS_TO_MODE[constants.TN_MODE])
                inputs.append(test_input)
            outputs = tn_model._infer(
                inputs,
                directions,
                do_basic_tokenization=do_basic_tokenization,
            )[-1]
            if cfg.mode in ['joint', 'itn']:
                print(f'Prediction (ITN): {outputs[0]}')
            if cfg.mode in ['joint', 'tn']:
                print(f'Prediction (TN): {outputs[-1]}')