# Imports assumed by the scripts below (paths follow NeMo's example layout);
# `get_args` and `load_manifest` are script-local helpers that are not shown here.
import json
import os

import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf

from nemo.collections.nlp.models import PunctuationCapitalizationModel
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


def main():
    args = get_args()
    # Load the model either from a local .nemo checkpoint or from NGC.
    if args.pretrained_name is None:
        model = PunctuationCapitalizationModel.restore_from(args.model_path)
    else:
        model = PunctuationCapitalizationModel.from_pretrained(args.pretrained_name)
    # Collect input texts from a plain-text file (one query per line) or from
    # a JSON-lines manifest.
    if args.input_manifest is None:
        texts = []
        with args.input_text.open() as f:
            for line in f:
                texts.append(line.strip())
    else:
        manifest = load_manifest(args.input_manifest)
        text_key = "pred_text" if "pred_text" in manifest[0] else "text"
        texts = [item[text_key] for item in manifest]
    processed_texts = model.add_punctuation_capitalization(
        texts,
        batch_size=args.batch_size,
        max_seq_length=args.max_seq_length,
        step=args.step,
        margin=args.margin,
    )
    # Write results back either as plain text or as an updated manifest.
    if args.output_manifest is None:
        with args.output_text.open('w') as f:
            for t in processed_texts:
                f.write(t + '\n')
    else:
        with args.output_manifest.open('w') as f:
            for item, t in zip(manifest, processed_texts):
                item[text_key] = t
                f.write(json.dumps(item) + '\n')
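
# A minimal sketch of the inference API exercised by main() above, assuming
# NeMo is installed; 'punctuation_en_bert' is an example model name from NGC,
# not something this script requires.
def _example_inference():
    model = PunctuationCapitalizationModel.from_pretrained('punctuation_en_bert')
    results = model.add_punctuation_capitalization(
        ['how are you', 'we bought four shirts'], batch_size=2, max_seq_length=64
    )
    # One punctuated, capitalized string per input, e.g. 'How are you?'
    print(results)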
def main(cfg: DictConfig) -> None:
    torch.manual_seed(42)
    # Fill in defaults from the dataclass-based config before building anything.
    cfg = OmegaConf.merge(OmegaConf.structured(PunctuationCapitalizationConfig()), cfg)
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    if not cfg.do_training and not cfg.do_testing:
        raise ValueError("At least one of config parameters `do_training` and `do_testing` has to be `true`.")
    if cfg.do_training:
        if cfg.model.get('train_ds') is None:
            raise ValueError('`model.train_ds` config section is required if `do_training` config item is `True`.')
    if cfg.do_testing:
        if cfg.model.get('test_ds') is None:
            raise ValueError('`model.test_ds` config section is required if `do_testing` config item is `True`.')
    if not cfg.pretrained_model:
        logging.info(f'Config: {OmegaConf.to_yaml(cfg)}')
        model = PunctuationCapitalizationModel(cfg.model, trainer=trainer)
    else:
        if os.path.exists(cfg.pretrained_model):
            model = PunctuationCapitalizationModel.restore_from(cfg.pretrained_model)
        elif cfg.pretrained_model in PunctuationCapitalizationModel.get_available_model_names():
            model = PunctuationCapitalizationModel.from_pretrained(cfg.pretrained_model)
        else:
            raise ValueError(
                f'Provide path to the pre-trained .nemo file or choose from '
                f'{PunctuationCapitalizationModel.list_available_models()}'
            )
        # A restored model keeps the config it was trained with; override the
        # dataset and optimization sections with the current run's config.
        model.update_config_after_restoring_from_checkpoint(
            class_labels=cfg.model.class_labels,
            common_dataset_parameters=cfg.model.common_dataset_parameters,
            train_ds=cfg.model.get('train_ds') if cfg.do_training else None,
            validation_ds=cfg.model.get('validation_ds') if cfg.do_training else None,
            test_ds=cfg.model.get('test_ds') if cfg.do_testing else None,
            optim=cfg.model.get('optim') if cfg.do_training else None,
        )
        model.set_trainer(trainer)
        if cfg.do_training:
            model.setup_training_data()
            model.setup_validation_data()
            model.setup_optimization()
        else:
            model.setup_test_data()
    if cfg.do_training:
        trainer.fit(model)
    if cfg.do_testing:
        trainer.test(model)
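
# Sketch of the Hydra entry point that typically wraps a main() like the one
# above in NeMo example scripts; the config path and name here are
# assumptions, not values taken from this script.
from nemo.core.config import hydra_runner

@hydra_runner(config_path='conf', config_name='punctuation_capitalization_config')
def run(cfg: DictConfig) -> None:
    main(cfg)

if __name__ == '__main__':
    run()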
def main():
    args = get_args()
    if args.pretrained_name is None:
        model = PunctuationCapitalizationModel.restore_from(args.model_path)
    else:
        model = PunctuationCapitalizationModel.from_pretrained(args.pretrained_name)
    # Pick a device: honor an explicit device argument, otherwise use CUDA
    # when it is available.
    if args.device is None:
        if torch.cuda.is_available():
            model = model.cuda()
        else:
            model = model.cpu()
    else:
        model = model.to(args.device)
    if args.input_manifest is None:
        texts = []
        with args.input_text.open() as f:
            for line in f:
                texts.append(line.strip())
    else:
        manifest = load_manifest(args.input_manifest)
        text_key = "pred_text" if "pred_text" in manifest[0] else "text"
        texts = [item[text_key] for item in manifest]
    processed_texts = model.add_punctuation_capitalization(
        texts,
        batch_size=args.batch_size,
        max_seq_length=args.max_seq_length,
        step=args.step,
        margin=args.margin,
        return_labels=args.save_labels_instead_of_text,
    )
    if args.output_manifest is None:
        args.output_text.parent.mkdir(exist_ok=True, parents=True)
        with args.output_text.open('w') as f:
            for t in processed_texts:
                f.write(t + '\n')
    else:
        args.output_manifest.parent.mkdir(exist_ok=True, parents=True)
        with args.output_manifest.open('w') as f:
            for item, t in zip(manifest, processed_texts):
                item[text_key] = t
                f.write(json.dumps(item) + '\n')
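
# Sketch of the JSON-lines manifest format consumed above: one JSON object
# per line, with the transcript stored under "text" (or "pred_text" when it
# comes from an ASR model). Field values here are illustrative only.
def _write_example_manifest(path):
    items = [
        {"audio_filepath": "sample_0.wav", "pred_text": "how are you"},
        {"audio_filepath": "sample_1.wav", "pred_text": "what can i do for you today"},
    ]
    with open(path, 'w') as f:
        for item in items:
            f.write(json.dumps(item) + '\n')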
def main(cfg: DictConfig) -> None:
    trainer = pl.Trainer(**cfg.trainer)
    exp_manager(trainer, cfg.get("exp_manager", None))
    if not cfg.pretrained_model:
        logging.info(f'Config: {OmegaConf.to_yaml(cfg)}')
        model = PunctuationCapitalizationModel(cfg.model, trainer=trainer)
    else:
        if os.path.exists(cfg.pretrained_model):
            model = PunctuationCapitalizationModel.restore_from(cfg.pretrained_model)
        elif cfg.pretrained_model in PunctuationCapitalizationModel.get_available_model_names():
            model = PunctuationCapitalizationModel.from_pretrained(cfg.pretrained_model)
        else:
            raise ValueError(
                f'Provide path to the pre-trained .nemo file or choose from '
                f'{PunctuationCapitalizationModel.list_available_models()}'
            )
        data_dir = cfg.model.dataset.get('data_dir', None)
        if data_dir:
            if not os.path.exists(data_dir):
                raise ValueError(f'{data_dir} is not found')
            # Fine-tuning the pre-trained model requires pointing it at the
            # new data directory before rebuilding the data loaders.
            model.update_data_dir(data_dir)
            # Set up train and validation PyTorch DataLoaders.
            model.setup_training_data()
            model.setup_validation_data()
            logging.info('Using the config file of the pre-trained model')
        else:
            raise ValueError(
                'Specify a valid dataset directory that contains train_ds.text_file and '
                'train_ds.labels_file with the "model.dataset.data_dir" argument'
            )
    trainer.fit(model)
    if cfg.model.nemo_path:
        model.save_to(cfg.model.nemo_path)
        logging.info(f'The model was saved to {cfg.model.nemo_path}')
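
# Sketch of the word-aligned dataset this script expects in
# model.dataset.data_dir. The file names below are NeMo's usual defaults but
# should be treated as assumptions; each label pairs a punctuation mark with
# a capitalization flag, e.g. "OU" = no punctuation + capitalize first
# letter, "?O" = question mark + lower case.
def _write_toy_dataset(data_dir):
    os.makedirs(data_dir, exist_ok=True)
    with open(os.path.join(data_dir, 'text_train.txt'), 'w') as f:
        f.write('hello how are you\n')
    with open(os.path.join(data_dir, 'labels_train.txt'), 'w') as f:
        f.write('OU OO OO ?O\n')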
def main(cfg: DictConfig) -> None:
    logging.info(
        'During evaluation/testing, it is currently advisable to construct a new Trainer '
        'with a single GPU and no DDP to obtain accurate results'
    )
    if not hasattr(cfg.model, 'test_ds'):
        raise ValueError('model.test_ds was not found in the config')
    gpu = 1 if cfg.trainer.gpus != 0 else 0
    trainer = pl.Trainer(
        gpus=gpu,
        precision=cfg.trainer.precision,
        amp_level=cfg.trainer.amp_level,
        logger=False,
        checkpoint_callback=False,
    )
    exp_dir = exp_manager(trainer, cfg.exp_manager)
    if not cfg.pretrained_model:
        raise ValueError(
            'To run the evaluation and inference script, a pre-trained model or .nemo file must be provided. '
            f'Choose from {PunctuationCapitalizationModel.list_available_models()} '
            'or set "pretrained_model"="your_model.nemo"'
        )
    if os.path.exists(cfg.pretrained_model):
        model = PunctuationCapitalizationModel.restore_from(cfg.pretrained_model)
    elif cfg.pretrained_model in PunctuationCapitalizationModel.get_available_model_names():
        model = PunctuationCapitalizationModel.from_pretrained(cfg.pretrained_model)
    else:
        raise ValueError(
            f'Provide path to the pre-trained .nemo file or choose from '
            f'{PunctuationCapitalizationModel.list_available_models()}'
        )
    data_dir = cfg.model.dataset.get('data_dir', None)
    if data_dir is None:
        logging.error(
            'No dataset directory provided. Skipping evaluation. '
            'To run evaluation on a file, specify a path to the directory that contains '
            'test_ds.text_file and test_ds.labels_file with the "model.dataset.data_dir" argument.'
        )
    elif not os.path.exists(data_dir):
        logging.error(f'{data_dir} is not found, skipping evaluation on the test set.')
    else:
        model.update_data_dir(data_dir=data_dir)
        model._cfg.dataset = cfg.model.dataset
        if model.prepare_test(trainer):
            model.setup_test_data(cfg.model.test_ds)
            trainer.test(model)
        else:
            logging.error('Skipping the evaluation. The trainer is not set up properly.')
    # Run inference on a few examples.
    queries = [
        'we bought four shirts one pen and a mug from the nvidia gear store in santa clara',
        'what can i do for you today',
        'how are you',
    ]
    inference_results = model.add_punctuation_capitalization(queries, batch_size=len(queries), max_seq_length=512)
    for query, result in zip(queries, inference_results):
        logging.info(f'Query : {query}')
        logging.info(f'Result: {result.strip()}\n')
    logging.info(f'Results are saved at {exp_dir}')
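
# A standalone sketch of the same single-GPU evaluation flow, assuming a
# pre-trained model whose restored config already contains a usable test_ds
# section; 'punctuation_en_bert' is an example model name from NGC.
def _example_evaluation():
    model = PunctuationCapitalizationModel.from_pretrained('punctuation_en_bert')
    trainer = pl.Trainer(gpus=1, logger=False, checkpoint_callback=False)
    if model.prepare_test(trainer):
        model.setup_test_data(model._cfg.test_ds)  # assumes test_ds is configured
        trainer.test(model)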