import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf

from nemo.collections.nlp.models import IntentSlotClassificationModel
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


# NOTE: the config_path/config_name values below are assumptions; point them at
# the YAML config you actually use for this experiment.
@hydra_runner(config_path="conf", config_name="intent_slot_classification_config")
def main(cfg: DictConfig) -> None:
    logging.info(f'Config Params:\n {OmegaConf.to_yaml(cfg)}')
    trainer = pl.Trainer(**cfg.trainer)
    log_dir = exp_manager(trainer, cfg.get("exp_manager", None))

    # initialize the model using the config file
    model = IntentSlotClassificationModel(cfg.model, trainer=trainer)

    # training
    logging.info(
        "================================================================================================"
    )
    logging.info('Starting training...')
    trainer.fit(model)
    logging.info('Training finished!')

    # stop further testing, as fast_dev_run does not save checkpoints
    if trainer.fast_dev_run:
        return

    # after training is done, load the model from the saved checkpoint
    # and evaluate it on a data file or on given queries
    logging.info(
        "================================================================================================"
    )
    logging.info("Starting the testing of the trained model on the test set...")
    logging.info("We will load the latest checkpoint saved during training...")

    # load the model from the saved .nemo checkpoint
    checkpoint_path = str(log_dir) + '/checkpoints/IntentSlot.nemo'
    eval_model = IntentSlotClassificationModel.restore_from(restore_path=checkpoint_path)

    # set up the test data, reusing the test section of the same config
    eval_model.update_data_dir_for_testing(data_dir=cfg.model.data_dir)
    eval_model.setup_test_data(test_data_config=cfg.model.test_ds)

    trainer.test(model=eval_model, ckpt_path=None, verbose=False)
    logging.info("Testing finished!")

    # run inference on a few examples
    logging.info(
        "======================================================================================"
    )
    logging.info("Evaluate the model on the given queries...")

    # these queries match the Assistant dataset;
    # change them appropriately for your own dataset
    queries = [
        'set alarm for seven thirty am',
        'lower volume by fifty percent',
        'what is my schedule for tomorrow',
    ]

    pred_intents, pred_slots = eval_model.predict_from_examples(queries, cfg.model.test_ds)

    logging.info('The prediction results of some sample queries with the trained model:')
    for query, intent, slots in zip(queries, pred_intents, pred_slots):
        logging.info(f'Query : {query}')
        logging.info(f'Predicted Intent: {intent}')
        logging.info(f'Predicted Slots: {slots}')
    logging.info("Inference finished!")


if __name__ == '__main__':
    main()
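
# Usage note (a sketch, not taken from this file): with the hydra_runner
# decorator, the script is launched from the command line and any field of the
# YAML config can be overridden with dotted keys. The script filename, data
# directory, and override values below are assumptions used only for
# illustration:
#
#   python intent_slot_classification.py \
#       model.data_dir=/path/to/assistant_data \
#       trainer.max_epochs=5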