Example #1
def main(cfg: DictConfig) -> None:
    lang = cfg.lang

    if cfg.tagger_pretrained_model:
        tagger_trainer, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, False)
        tagger_model.max_sequence_len = 512
        tagger_model.setup_test_data(cfg.data.test_ds)
        logging.info('Evaluating the tagger...')
        tagger_trainer.test(model=tagger_model, verbose=False)
    else:
        logging.info('Tagger checkpoint is not provided, skipping tagger evaluation')

    if cfg.decoder_pretrained_model:
        decoder_trainer, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, False)
        decoder_model.max_sequence_len = 512
        decoder_model.setup_multiple_test_data(cfg.data.test_ds)
        logging.info('Evaluating the decoder...')
        decoder_trainer.test(decoder_model)
    else:
        logging.info('Decoder checkpoint is not provided, skipping decoder evaluation')

    if cfg.tagger_pretrained_model and cfg.decoder_pretrained_model:
        logging.info('Running evaluation of the duplex model (tagger + decoder) on the test set.')
        tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model, lang)
        test_dataset = TextNormalizationTestDataset(cfg.data.test_ds.data_path, cfg.mode, lang)
        results = tn_model.evaluate(test_dataset, cfg.data.test_ds.batch_size, cfg.data.test_ds.errors_log_fp)
        print(f'\nTest results: {results}')
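For reference, this entry point can be driven directly with a hand-built config. A minimal sketch, assuming placeholder file names: the keys mirror exactly what main() reads above, while real configs also carry the model and trainer settings consumed inside instantiate_model_and_trainer.

from omegaconf import OmegaConf

# All values below are illustrative placeholders, not defaults.
cfg = OmegaConf.create({
    'lang': 'en',
    'mode': 'joint',
    'tagger_pretrained_model': 'tagger.nemo',    # a falsy value skips tagger evaluation
    'decoder_pretrained_model': 'decoder.nemo',  # a falsy value skips decoder evaluation
    'data': {
        'test_ds': {'data_path': 'test.tsv', 'batch_size': 32, 'errors_log_fp': 'errors.log'},
    },
})
main(cfg)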
Example #2
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')
    tagger_trainer, tagger_model = instantiate_model_and_trainer(
        cfg, TAGGER_MODEL, False)
    decoder_trainer, decoder_model = instantiate_model_and_trainer(
        cfg, DECODER_MODEL, False)
    tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model)

    if not cfg.inference.interactive:
        # Set up the test dataset
        test_dataset = TextNormalizationTestDataset(cfg.data.test_ds.data_path,
                                                    cfg.data.test_ds.mode)
        results = tn_model.evaluate(test_dataset, cfg.data.test_ds.batch_size,
                                    cfg.inference.errors_log_fp)
        print(f'\nTest results: {results}')
    else:
        while True:
            test_input = input('Enter a test input: ')
            test_input = ' '.join(word_tokenize(test_input))
            outputs = tn_model._infer(
                [test_input, test_input],
                [constants.INST_BACKWARD, constants.INST_FORWARD])[-1]
            print(f'Prediction (ITN): {outputs[0]}')
            print(f'Prediction (TN): {outputs[1]}')

            should_continue = input('\nContinue (y/n): ').strip().lower()
            if should_continue.startswith('n'):
                break
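The same kind of sketch for this variant, again with placeholder paths. word_tokenize is presumably NLTK's, which needs its 'punkt' tokenizer data downloaded; interactive mode never touches the test set.

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'inference': {'interactive': True, 'errors_log_fp': 'errors.log'},
    'data': {'test_ds': {'data_path': 'test.tsv', 'mode': 'joint', 'batch_size': 32}},
})
main(cfg)  # loops, printing ITN and TN predictions, until the user answers 'n'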
Example #3
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')

    # Train the tagger
    if cfg.tagger_model.do_training:
        logging.info(
            "================================================================================================"
        )
        logging.info('Starting training tagger...')
        tagger_trainer, tagger_model = instantiate_model_and_trainer(
            cfg, TAGGER_MODEL, True)
        exp_manager(tagger_trainer, cfg.get('tagger_exp_manager', None))
        tagger_trainer.fit(tagger_model)
        if cfg.tagger_model.nemo_path:
            tagger_model.to(tagger_trainer.accelerator.root_device)
            tagger_model.save_to(cfg.tagger_model.nemo_path)
        logging.info('Training finished!')

    # Train the decoder
    if cfg.decoder_model.do_training:
        logging.info(
            "================================================================================================"
        )
        logging.info('Starting training decoder...')
        decoder_trainer, decoder_model = instantiate_model_and_trainer(
            cfg, DECODER_MODEL, True)
        exp_manager(decoder_trainer, cfg.get('decoder_exp_manager', None))
        decoder_trainer.fit(decoder_model)
        if cfg.decoder_model.nemo_path:
            decoder_model.to(decoder_trainer.accelerator.root_device)
            decoder_model.save_to(cfg.decoder_model.nemo_path)
        logging.info('Training finished!')
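A minimal config sketch for this training script, with placeholder values; the exp-manager sections and all hyperparameters consumed inside instantiate_model_and_trainer are omitted.

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'tagger_model': {'do_training': True, 'nemo_path': 'tagger.nemo'},    # nemo_path=None skips saving
    'decoder_model': {'do_training': True, 'nemo_path': 'decoder.nemo'},
    # 'tagger_exp_manager' / 'decoder_exp_manager' are optional: cfg.get(..., None)
})
main(cfg)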
Example #4
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')
    lang, batch_size = cfg.lang, cfg.data.test_ds.batch_size
    tagger_trainer, tagger_model = instantiate_model_and_trainer(
        cfg, TAGGER_MODEL, False)
    decoder_trainer, decoder_model = instantiate_model_and_trainer(
        cfg, DECODER_MODEL, False)

    # Evaluating the tagger
    print('Evaluating the tagger')
    tagger_model.setup_test_data(cfg.data.test_ds)
    tagger_trainer.test(model=tagger_model, verbose=False)

    # Evaluating the decoder
    print('Evaluating the decoder')
    transformer_model, tokenizer = decoder_model.model, decoder_model._tokenizer
    try:
        model_max_len = transformer_model.config.n_positions
    except AttributeError:
        model_max_len = 512
    # Load the test dataset
    decoder_model.setup_test_data(cfg.data.test_ds)
    test_dataset, test_dl = decoder_model.test_dataset, decoder_model._test_dl
    # Inference
    itn_class2stats, tn_class2stats = {}, {}
    for ix, examples in tqdm(enumerate(test_dl)):
        # Extract info for the current batch
        start_idx = ix * batch_size
        end_idx = min((ix + 1) * batch_size, len(test_dataset))
        batch_insts = test_dataset.insts[start_idx:end_idx]
        batch_input_centers = [inst.input_center_str for inst in batch_insts]
        batch_targets = [inst.output_str for inst in batch_insts]
        batch_dirs = [inst.direction for inst in batch_insts]
        batch_classes = [inst.semiotic_class for inst in batch_insts]
        # Inference
        input_ids = examples['input_ids'].to(decoder_model.device)
        generated_ids = transformer_model.generate(input_ids,
                                                   max_length=model_max_len)
        batch_preds = tokenizer.batch_decode(generated_ids,
                                             skip_special_tokens=True)
        batch_preds = decoder_model.postprocess_output_spans(
            batch_input_centers, batch_preds, batch_dirs)
        # Update itn_class2stats and tn_class2stats
        for direction, _class, pred, target in zip(batch_dirs, batch_classes,
                                                   batch_preds, batch_targets):
            correct = TextNormalizationTestDataset.is_same(
                pred, target, direction, lang)
            stats = itn_class2stats if direction == constants.INST_BACKWARD else tn_class2stats
            if _class not in stats:
                stats[_class] = []
            stats[_class].append(int(correct))

    # Print out stats
    print('ITN (Backward Direction)')
    print_class_based_stats(itn_class2stats)
    print('TN (Forward Direction)')
    print_class_based_stats(tn_class2stats)
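The two accumulators map each semiotic class to a list of 0/1 correctness flags, so per-class accuracy is just the mean of each list. A toy illustration of the structure print_class_based_stats receives (class names are examples only):

itn_class2stats = {'CARDINAL': [1, 0, 1], 'DATE': [1, 1]}
for _class, flags in itn_class2stats.items():
    print(_class, sum(flags) / len(flags))  # per-class accuracy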
Example #5
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')

    # Train the tagger
    if cfg.tagger_model.do_training:
        logging.info(
            "================================================================================================"
        )
        logging.info('Starting training tagger...')
        tagger_trainer, tagger_model = instantiate_model_and_trainer(cfg, TAGGER_MODEL, True)
        tagger_exp_manager = cfg.get('tagger_exp_manager', None)
        exp_manager(tagger_trainer, tagger_exp_manager)
        tagger_trainer.fit(tagger_model)
        if (
            tagger_exp_manager
            and tagger_exp_manager.get('create_checkpoint_callback', False)
            and cfg.tagger_model.nemo_path
        ):
            tagger_model.to(tagger_trainer.accelerator.root_device)
            tagger_model.save_to(cfg.tagger_model.nemo_path)
        logging.info('Training finished!')

    # Train the decoder
    if cfg.decoder_model.do_training:
        logging.info(
            "================================================================================================"
        )
        logging.info('Starting training decoder...')
        decoder_trainer, decoder_model = instantiate_model_and_trainer(cfg, DECODER_MODEL, True)
        decoder_exp_manager = cfg.get('decoder_exp_manager', None)
        exp_manager(decoder_trainer, decoder_exp_manager)
        decoder_trainer.fit(decoder_model)
        if (
            decoder_exp_manager
            and decoder_exp_manager.get('create_checkpoint_callback', False)
            and cfg.decoder_model.nemo_path
        ):
            decoder_model.to(decoder_trainer.accelerator.root_device)
            decoder_model.save_to(cfg.decoder_model.nemo_path)
        logging.info('Training finished!')

    # Evaluation after training
    if (
        hasattr(cfg.data, 'test_ds')
        and cfg.data.test_ds.data_path is not None
        and cfg.tagger_model.do_training
        and cfg.decoder_model.do_training
    ):
        tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model, cfg.lang)
        test_dataset = TextNormalizationTestDataset(cfg.data.test_ds.data_path, cfg.mode, cfg.lang)
        results = tn_model.evaluate(test_dataset, cfg.data.test_ds.batch_size, cfg.data.test_ds.errors_log_fp)
        print(f'\nTest results: {results}')
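Compared with Example #3, the .nemo files are written only when the corresponding exp manager enables checkpointing, and the final duplex evaluation runs only when both models were trained in the same invocation. A config sketch with placeholder values:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'lang': 'en',
    'mode': 'joint',
    'tagger_model': {'do_training': True, 'nemo_path': 'tagger.nemo'},
    'tagger_exp_manager': {'create_checkpoint_callback': True},
    'decoder_model': {'do_training': True, 'nemo_path': 'decoder.nemo'},
    'decoder_exp_manager': {'create_checkpoint_callback': True},
    'data': {'test_ds': {'data_path': 'test.tsv', 'batch_size': 32, 'errors_log_fp': 'errors.log'}},
})
main(cfg)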
Example #6
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')

    # Train the model
    if cfg.model.do_training:
        logging.info(
            "================================================================================================"
        )
        logging.info('Start training...')
        trainer, model = instantiate_model_and_trainer(cfg, ITN_MODEL, True)
        thutmose_tagger_exp_manager = cfg.get('exp_manager', None)
        exp_manager(trainer, thutmose_tagger_exp_manager)
        trainer.fit(model)
        logging.info('Training finished!')
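Only two top-level keys are read directly here; everything else lives under whatever instantiate_model_and_trainer expects. A placeholder sketch:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'model': {'do_training': True},
    # 'exp_manager': {...} is optional; cfg.get('exp_manager', None) tolerates its absence
})
main(cfg)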
Example #7
def main(cfg: DictConfig) -> None:
    logging.debug(f'Config Params: {OmegaConf.to_yaml(cfg)}')

    if cfg.pretrained_model is None:
        raise ValueError("A pre-trained model should be provided.")
    _, model = instantiate_model_and_trainer(cfg, ITN_MODEL, False)

    text_file = cfg.inference.from_file
    logging.info(f"Running inference on {text_file}...")
    if not os.path.exists(text_file):
        raise ValueError(f"{text_file} not found.")

    with open(text_file, "r", encoding="utf-8") as f:
        lines = f.readlines()

    batch_size = cfg.inference.get("batch_size", 8)

    batch, all_preds = [], []
    for i, line in enumerate(lines):
        s = spoken_preprocessing(line)  # the same input transformation as in corpus preparation
        batch.append(s.strip())
        if len(batch) == batch_size or i == len(lines) - 1:
            outputs = model._infer(batch)
            all_preds.extend(outputs)
            batch = []
    if len(all_preds) != len(lines):
        raise ValueError(
            f"Number of input lines and predictions differ: "
            f"predictions={len(all_preds)}; lines={len(lines)}"
        )
    out_file = cfg.inference.out_file
    with open(f"{out_file}", "w", encoding="utf-8") as f_out:
        f_out.write("\n".join(all_preds))
    logging.info(f"Predictions saved to {out_file}.")
Example #8
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params: {OmegaConf.to_yaml(cfg)}')
    lang = cfg.lang
    do_basic_tokenization = True
    if cfg.decoder_pretrained_model is None or cfg.tagger_pretrained_model is None:
        raise ValueError(
            "Both pre-trained models (DuplexTaggerModel and DuplexDecoderModel) should be provided."
        )
    tagger_trainer, tagger_model = instantiate_model_and_trainer(
        cfg, TAGGER_MODEL, False)
    decoder_trainer, decoder_model = instantiate_model_and_trainer(
        cfg, DECODER_MODEL, False)
    tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model, lang)

    if cfg.inference.get("from_file", False):
        text_file = cfg.inference.from_file
        logging.info(f'Running inference on {text_file}...')
        if not os.path.exists(text_file):
            raise ValueError(f'{text_file} not found.')

        with open(text_file, 'r') as f:
            lines = f.readlines()

        def _get_predictions(lines: List[str], mode: str, batch_size: int,
                             text_file: str):
            """ Runs inference on a batch data without labels and saved predictions to a file. """
            assert mode in ['tn', 'itn']
            file_name, extension = os.path.splitext(text_file)
            batch, all_preds = [], []
            for i, line in enumerate(lines):
                batch.append(line.strip())
                if len(batch) == batch_size or i == len(lines) - 1:
                    outputs = tn_model._infer(
                        batch,
                        [constants.DIRECTIONS_TO_MODE[mode]] * len(batch),
                        do_basic_tokenization=do_basic_tokenization,
                    )
                    all_preds.extend(outputs[-1])
                    batch = []
            assert len(all_preds) == len(lines)
            out_file = f'{file_name}_{mode}{extension}'
            with open(out_file, 'w') as f_out:
                f_out.write("\n".join(all_preds))
            logging.info(f'Predictions for {mode} saved to {out_file}.')

        batch_size = cfg.inference.get("batch_size", 8)
        if cfg.mode in ['tn', 'joint']:
            # TN mode
            _get_predictions(lines, 'tn', batch_size, text_file)
        if cfg.mode in ['itn', 'joint']:
            # ITN mode
            _get_predictions(lines, 'itn', batch_size, text_file)

    else:
        print('Entering interactive mode.')
        done = False
        while not done:
            print('Type "STOP" to exit.')
            test_input = input('Enter a test input: ')
            if test_input == "STOP":
                done = True
            if not done:
                directions = []
                inputs = []
                if cfg.mode in ['itn', 'joint']:
                    directions.append(
                        constants.DIRECTIONS_TO_MODE[constants.ITN_MODE])
                    inputs.append(test_input)
                if cfg.mode in ['tn', 'joint']:
                    directions.append(
                        constants.DIRECTIONS_TO_MODE[constants.TN_MODE])
                    inputs.append(test_input)
                outputs = tn_model._infer(
                    inputs,
                    directions,
                    do_basic_tokenization=do_basic_tokenization,
                )[-1]
                if cfg.mode in ['joint', 'itn']:
                    print(f'Prediction (ITN): {outputs[0]}')
                if cfg.mode in ['joint', 'tn']:
                    print(f'Prediction (TN): {outputs[-1]}')
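Finally, a placeholder sketch for this duplex inference script. With from_file set it writes one output file per requested direction next to the input (e.g. input_tn.txt and input_itn.txt for mode 'joint'); with from_file unset or falsy it drops into the interactive loop.

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'lang': 'en',
    'mode': 'joint',  # 'tn', 'itn', or 'joint'
    'tagger_pretrained_model': 'tagger.nemo',
    'decoder_pretrained_model': 'decoder.nemo',
    'inference': {'from_file': 'input.txt', 'batch_size': 8},
})
main(cfg)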