Exemplo n.º 1
0
def main(cfg: DictConfig) -> None:
    """Train, test and demo an intent/slot classification model.

    Steps:
      1. Build a Lightning trainer from ``cfg.trainer`` and set up experiment
         management from the optional ``cfg.exp_manager`` section.
      2. Train an ``IntentSlotClassificationModel`` built from ``cfg.model``.
      3. Restore the ``.nemo`` checkpoint written under the experiment log
         directory (falling back to the just-trained model if no checkpoint
         is available) and evaluate it on ``cfg.model.test_ds``.
      4. Run inference on a few hard-coded example queries and log the results.

    Args:
        cfg: Hydra/OmegaConf config with a ``trainer`` section, an optional
            ``exp_manager`` section, and a ``model`` section that provides at
            least ``data_dir`` and ``test_ds``.
    """
    import os  # local import: only used to build/check the checkpoint path

    logging.info(f'Config Params:\n {OmegaConf.to_yaml(cfg)}')
    trainer = pl.Trainer(**cfg.trainer)
    # log_dir may be None when experiment management is disabled
    # (NOTE(review): confirm against the exp_manager version in use).
    log_dir = exp_manager(trainer, cfg.get("exp_manager", None))

    # initialize the model using the config file
    model = IntentSlotClassificationModel(cfg.model, trainer=trainer)

    # training
    logging.info(
        "================================================================================================"
    )
    logging.info('Starting training...')
    trainer.fit(model)
    logging.info('Training finished!')

    # Stop further testing as fast_dev_run does not save checkpoints
    if trainer.fast_dev_run:
        return

    # after model training is done, you can load the model from the saved checkpoint
    # and evaluate it on a data file or on given queries.
    logging.info(
        "================================================================================================"
    )
    logging.info("Starting the testing of the trained model on test set...")
    logging.info(
        "We will load the latest model saved checkpoint from the training...")

    # Locate the saved .nemo checkpoint. Without a log directory there is no
    # checkpoint path at all; with one, the file may still be absent if
    # checkpoint saving was disabled. Rather than crash in restore_from after
    # a full training run, fall back to evaluating the model just trained.
    checkpoint_path = (
        os.path.join(log_dir, 'checkpoints', 'IntentSlot.nemo')
        if log_dir is not None
        else None
    )
    if checkpoint_path is not None and os.path.isfile(checkpoint_path):
        eval_model = IntentSlotClassificationModel.restore_from(
            restore_path=checkpoint_path)
    else:
        logging.warning(
            f'No .nemo checkpoint found at {checkpoint_path}; '
            'evaluating the model instance that was just trained instead.'
        )
        eval_model = model

    # we will setup testing data reusing the same config (test section)
    eval_model.update_data_dir_for_testing(data_dir=cfg.model.data_dir)
    eval_model.setup_test_data(test_data_config=cfg.model.test_ds)

    trainer.test(model=eval_model, ckpt_path=None, verbose=False)
    logging.info("Testing finished!")

    # run an inference on a few examples
    logging.info(
        "======================================================================================"
    )
    logging.info("Evaluate the model on the given queries...")

    # this will work well if you train the model on Assistant dataset
    # for your own dataset change the examples appropriately
    queries = [
        'set alarm for seven thirty am',
        'lower volume by fifty percent',
        'what is my schedule for tomorrow',
    ]

    pred_intents, pred_slots = eval_model.predict_from_examples(
        queries, cfg.model.test_ds)

    logging.info(
        'The prediction results of some sample queries with the trained model:'
    )
    for query, intent, slots in zip(queries, pred_intents, pred_slots):
        logging.info(f'Query : {query}')
        logging.info(f'Predicted Intent: {intent}')
        logging.info(f'Predicted Slots: {slots}')

    logging.info("Inference finished!")