import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf

from nemo.collections.nlp.models import PunctuationCapitalizationModel
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


def main(cfg: DictConfig) -> None:
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    do_training = True
    if not cfg.pretrained_model:
        logging.info(f'Config: {OmegaConf.to_yaml(cfg)}')
        model = PunctuationCapitalizationModel(cfg.model, trainer=trainer)
    else:
        logging.info(f'Loading pretrained model {cfg.pretrained_model}')
        # TODO: remove strict=False once Lightning supports persistent parameters for add_state()
        model = PunctuationCapitalizationModel.from_pretrained(cfg.pretrained_model, strict=False)
        data_dir = cfg.model.dataset.get('data_dir', None)
        if data_dir:
            # we can also fine-tune the pretrained model, but that requires
            # setting up train and validation PyTorch DataLoaders
            model.setup_training_data(data_dir=data_dir)
            # evaluation can be done on multiple files; use model.validation_ds.ds_items
            # to specify multiple data directories if needed
            model.setup_validation_data(data_dirs=data_dir)
            logging.info('Using the config of the pretrained model')
        else:
            do_training = False
            logging.info(
                'A data directory must be specified for training/fine-tuning. '
                f'Using pretrained {cfg.pretrained_model} model weights and skipping fine-tuning.'
            )

    if do_training:
        trainer.fit(model)
        if cfg.model.nemo_path:
            model.save_to(cfg.model.nemo_path)

    logging.info(
        'During evaluation/testing, it is currently advisable to construct a new Trainer with a single GPU '
        'and no DDP to obtain accurate results'
    )
    gpu = 1 if cfg.trainer.gpus != 0 else 0
    trainer = pl.Trainer(gpus=gpu)
    model.set_trainer(trainer)

    # run inference on a few examples
    queries = [
        'we bought four shirts one pen and a mug from the nvidia gear store in santa clara',
        'what can i do for you today',
        'how are you',
    ]
    inference_results = model.add_punctuation_capitalization(queries)

    for query, result in zip(queries, inference_results):
        logging.info(f'Query : {query}')
        logging.info(f'Result: {result.strip()}\n')
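The .nemo checkpoint written by model.save_to() above can be reloaded later for standalone inference. A minimal sketch, assuming the checkpoint path below (illustrative) matches whatever cfg.model.nemo_path was set to during training:

from nemo.collections.nlp.models import PunctuationCapitalizationModel

# hypothetical path; use the value cfg.model.nemo_path pointed to during training
model = PunctuationCapitalizationModel.restore_from('punctuation_model.nemo')

queries = ['how are you']
# add_punctuation_capitalization returns one punctuated/capitalized string per input query
results = model.add_punctuation_capitalization(queries)
for query, result in zip(queries, results):
    print(f'{query} -> {result}')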
Example #2
def main(cfg: DictConfig) -> None:
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    do_training = True
    if not cfg.pretrained_model:
        logging.info(f'Config: {OmegaConf.to_yaml(cfg)}')
        model = PunctuationCapitalizationModel(cfg.model, trainer=trainer)
    else:
        logging.info(f'Loading pretrained model {cfg.pretrained_model}')
        model = PunctuationCapitalizationModel.from_pretrained(cfg.pretrained_model)
        data_dir = cfg.model.dataset.get('data_dir', None)
        if data_dir:
            # update the data directory in the model's dataset config, then rebuild
            # the train/validation dataloaders from that config
            model.update_data_dir(data_dir)
            model.setup_training_data()
            model.setup_validation_data()
            logging.info('Using the config of the pretrained model')
        else:
            do_training = False
            logging.info(
                'A data directory must be specified for training/fine-tuning. '
                f'Using pretrained {cfg.pretrained_model} model weights and skipping fine-tuning.'
            )

    if do_training:
        trainer.fit(model)
        if cfg.model.nemo_path:
            model.save_to(cfg.model.nemo_path)

    logging.info(
        'During evaluation/testing, it is currently advisable to construct a new Trainer with a single GPU '
        'and no DDP to obtain accurate results'
    )
    gpu = 1 if cfg.trainer.gpus != 0 else 0
    trainer = pl.Trainer(gpus=gpu)
    model.set_trainer(trainer)

    # run inference on a few examples
    queries = [
        'we bought four shirts one pen and a mug from the nvidia gear store in santa clara',
        'what can i do for you today',
        'how are you',
    ]
    inference_results = model.add_punctuation_capitalization(queries)

    for query, result in zip(queries, inference_results):
        logging.info(f'Query : {query}')
        logging.info(f'Result: {result.strip()}\n')
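For reference, a minimal sketch of the config layout that main() reads above (trainer, exp_manager, pretrained_model, model.dataset.data_dir, model.nemo_path). In the real NeMo training scripts the config normally comes from a Hydra YAML file; the pretrained model name and paths here are only illustrative:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        # keyword arguments forwarded to pl.Trainer(**cfg.trainer)
        'trainer': {'gpus': 1, 'max_epochs': 3},
        # passed to exp_manager(); left empty here so experiment logging is effectively disabled
        'exp_manager': None,
        # setting a pretrained model name triggers the fine-tuning branch above
        'pretrained_model': 'punctuation_en_bert',  # illustrative model name
        'model': {
            'dataset': {'data_dir': '/path/to/data'},  # hypothetical data directory
            'nemo_path': 'punctuation_model.nemo',     # where the trained model is saved
        },
    }
)
main(cfg)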