def main(cfg: ParallelAlignmentConfig):
    # Load the model from a local .nemo/.ckpt file or from a pretrained model name.
    if cfg.model.endswith(".nemo"):
        logging.info("Attempting to initialize from .nemo file")
        model = ASRModel.restore_from(restore_path=cfg.model, map_location="cpu")
    elif cfg.model.endswith(".ckpt"):
        logging.info("Attempting to initialize from .ckpt file")
        model = ASRModel.load_from_checkpoint(checkpoint_path=cfg.model, map_location="cpu")
    else:
        logging.info(
            "Attempting to initialize from a pretrained model as the model name does not have the extension of .nemo or .ckpt"
        )
        model = ASRModel.from_pretrained(model_name=cfg.model, map_location="cpu")

    trainer = ptl.Trainer(**cfg.trainer)

    # Sample ids are needed to map predictions back to manifest entries;
    # alignment does not return predictions and does not use CER.
    cfg.predict_ds.return_sample_id = True
    cfg.return_predictions = False
    cfg.use_cer = False
    cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model._cfg.train_ds)
    data_loader = model._setup_dataloader_from_config(cfg.predict_ds)

    os.makedirs(cfg.output_path, exist_ok=True)
    # trainer.global_rank is not valid before predict() is called. Need this hack to find the correct global_rank.
    global_rank = trainer.node_rank * trainer.num_devices + int(os.environ.get("LOCAL_RANK", 0))
    output_file = os.path.join(cfg.output_path, f"predictions_{global_rank}.json")
    output_ctm_dir = os.path.join(cfg.output_path, "ctm")
    predictor_writer = ASRCTMPredictionWriter(
        dataset=data_loader.dataset,
        output_file=output_file,
        output_ctm_dir=output_ctm_dir,
        time_per_frame=cfg.model_stride * model._cfg.preprocessor['window_stride'],
    )
    trainer.callbacks.extend([predictor_writer])

    aligner_wrapper = AlignerWrapperModel(model=model, cfg=cfg.aligner_args)
    trainer.predict(model=aligner_wrapper, dataloaders=data_loader, return_predictions=cfg.return_predictions)
    samples_num = predictor_writer.close_output_file()

    logging.info(
        f"Prediction on rank {global_rank} is done for {samples_num} samples and results are stored in {output_file}."
    )

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    samples_num = 0
    # Rank zero concatenates the per-rank prediction files into a single manifest.
    if is_global_rank_zero():
        output_file = os.path.join(cfg.output_path, "predictions_all.json")
        logging.info(f"Prediction files are being aggregated in {output_file}.")
        with open(output_file, 'tw', encoding="utf-8") as outf:
            for rank in range(trainer.world_size):
                input_file = os.path.join(cfg.output_path, f"predictions_{rank}.json")
                with open(input_file, 'r', encoding="utf-8") as inpf:
                    lines = inpf.readlines()
                    samples_num += len(lines)
                    outf.writelines(lines)
        logging.info(
            f"Prediction is done for {samples_num} samples in total on all workers and results are aggregated in {output_file}."
        )
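# Illustrative sketch (not part of the script above): how the CTM frame duration and the
# per-rank global_rank are derived. The concrete numbers (window_stride of 0.01 s, a
# model_stride of 4, and a 4-device node) are assumptions chosen for the example, not
# values read from any real config.
import os

window_stride = 0.01                 # assumed preprocessor window stride in seconds
model_stride = 4                     # assumed encoder subsampling factor
time_per_frame = model_stride * window_stride
print(f"each output frame spans {time_per_frame:.2f} s")            # 0.04 s per CTM frame

node_rank, num_devices = 1, 4                                        # assumed 2-node x 4-device job
local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = node_rank * num_devices + local_rank
print(f"rank {global_rank} writes predictions_{global_rank}.json")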
def main(cfg: ParallelTranscriptionConfig):
    # Load the model from a local .nemo/.ckpt file or from a pretrained model name.
    if cfg.model.endswith(".nemo"):
        logging.info("Attempting to initialize from .nemo file")
        model = ASRModel.restore_from(restore_path=cfg.model, map_location="cpu")
    elif cfg.model.endswith(".ckpt"):
        logging.info("Attempting to initialize from .ckpt file")
        model = ASRModel.load_from_checkpoint(checkpoint_path=cfg.model, map_location="cpu")
    else:
        logging.info(
            "Attempting to initialize from a pretrained model as the model name does not have the extension of .nemo or .ckpt"
        )
        model = ASRModel.from_pretrained(model_name=cfg.model, map_location="cpu")

    trainer = ptl.Trainer(**cfg.trainer)

    # Sample ids are needed to map predictions back to manifest entries.
    cfg.predict_ds.return_sample_id = True
    cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model.cfg.train_ds)
    data_loader = model._setup_dataloader_from_config(cfg.predict_ds)

    os.makedirs(cfg.output_path, exist_ok=True)
    # trainer.global_rank is not valid before predict() is called. Need this hack to find the correct global_rank.
    global_rank = trainer.node_rank * trainer.num_devices + int(os.environ.get("LOCAL_RANK", 0))
    output_file = os.path.join(cfg.output_path, f"predictions_{global_rank}.json")
    predictor_writer = ASRPredictionWriter(dataset=data_loader.dataset, output_file=output_file)
    trainer.callbacks.extend([predictor_writer])

    predictions = trainer.predict(model=model, dataloaders=data_loader, return_predictions=cfg.return_predictions)
    if predictions is not None:
        predictions = list(itertools.chain.from_iterable(predictions))
    samples_num = predictor_writer.close_output_file()

    logging.info(
        f"Prediction on rank {global_rank} is done for {samples_num} samples and results are stored in {output_file}."
    )

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    samples_num = 0
    pred_text_list = []
    text_list = []
    # Rank zero merges the per-rank prediction files and computes a corpus-level WER/CER.
    if is_global_rank_zero():
        output_file = os.path.join(cfg.output_path, "predictions_all.json")
        logging.info(f"Prediction files are being aggregated in {output_file}.")
        with open(output_file, 'w') as outf:
            for rank in range(trainer.world_size):
                input_file = os.path.join(cfg.output_path, f"predictions_{rank}.json")
                with open(input_file, 'r') as inpf:
                    lines = inpf.readlines()
                    for line in lines:
                        item = json.loads(line)
                        pred_text_list.append(item["pred_text"])
                        text_list.append(item["text"])
                        outf.write(json.dumps(item) + "\n")
                        samples_num += 1
        wer_cer = word_error_rate(hypotheses=pred_text_list, references=text_list, use_cer=cfg.use_cer)
        logging.info(
            f"Prediction is done for {samples_num} samples in total on all workers and results are aggregated in {output_file}."
        )
        logging.info("{} for all predictions is {:.4f}.".format("CER" if cfg.use_cer else "WER", wer_cer))
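# Illustrative sketch (not part of the script above): the aggregation step scores the merged
# hypotheses against the references with NeMo's word_error_rate helper, which is what the
# final log line reports. The hypothesis/reference strings below are made up for the example.
from nemo.collections.asr.metrics.wer import word_error_rate

hypotheses = ["the cat sat on the mat", "hello world"]
references = ["the cat sat on a mat", "hello word"]
wer = word_error_rate(hypotheses=hypotheses, references=references, use_cer=False)
cer = word_error_rate(hypotheses=hypotheses, references=references, use_cer=True)
print(f"WER {wer:.4f}, CER {cer:.4f}")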