示例#1
0
def main(cfg: ParallelAlignmentConfig):
    """Run alignment prediction with an ASR model and merge the per-rank outputs.

    The model is loaded from ``cfg.model``: a path ending in ``.nemo`` is
    restored, a path ending in ``.ckpt`` is loaded as a checkpoint, and any
    other value is treated as a pretrained model name. The model is wrapped in
    an ``AlignerWrapperModel`` and run through ``trainer.predict`` with an
    ``ASRCTMPredictionWriter`` callback that is given a per-rank output file
    (``predictions_<global_rank>.json``) and a ``ctm`` directory, both under
    ``cfg.output_path``. Afterwards global rank zero concatenates all per-rank
    files into ``predictions_all.json``.

    Args:
        cfg: Alignment configuration. The fields ``predict_ds``,
            ``return_predictions`` and ``use_cer`` are overwritten in place.
    """
    model_ref = cfg.model
    if model_ref.endswith(".nemo"):
        logging.info("Attempting to initialize from .nemo file")
        model = ASRModel.restore_from(restore_path=model_ref, map_location="cpu")
    elif model_ref.endswith(".ckpt"):
        logging.info("Attempting to initialize from .ckpt file")
        model = ASRModel.load_from_checkpoint(checkpoint_path=model_ref, map_location="cpu")
    else:
        logging.info(
            "Attempting to initialize from a pretrained model as the model name does not have the extension of .nemo or .ckpt"
        )
        model = ASRModel.from_pretrained(model_name=model_ref, map_location="cpu")

    trainer = ptl.Trainer(**cfg.trainer)

    # Force the settings this script relies on before building the dataloader.
    cfg.predict_ds.return_sample_id = True
    cfg.return_predictions = False
    cfg.use_cer = False
    cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model._cfg.train_ds)
    data_loader = model._setup_dataloader_from_config(cfg.predict_ds)

    os.makedirs(cfg.output_path, exist_ok=True)
    # trainer.global_rank is not valid before predict() is called, so the rank is
    # derived from the node rank and the LOCAL_RANK environment variable instead.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    global_rank = trainer.node_rank * trainer.num_devices + local_rank
    output_file = os.path.join(cfg.output_path, f"predictions_{global_rank}.json")
    output_ctm_dir = os.path.join(cfg.output_path, "ctm")
    time_per_frame = cfg.model_stride * model._cfg.preprocessor['window_stride']
    predictor_writer = ASRCTMPredictionWriter(
        dataset=data_loader.dataset,
        output_file=output_file,
        output_ctm_dir=output_ctm_dir,
        time_per_frame=time_per_frame,
    )
    trainer.callbacks.extend([predictor_writer])

    aligner_wrapper = AlignerWrapperModel(model=model, cfg=cfg.aligner_args)
    trainer.predict(model=aligner_wrapper, dataloaders=data_loader, return_predictions=cfg.return_predictions)
    samples_num = predictor_writer.close_output_file()

    logging.info(
        f"Prediction on rank {global_rank} is done for {samples_num} samples and results are stored in {output_file}."
    )

    # Wait for every rank to finish writing before the per-rank files are merged.
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    # Only global rank zero performs the aggregation.
    if not is_global_rank_zero():
        return

    samples_num = 0
    output_file = os.path.join(cfg.output_path, "predictions_all.json")
    logging.info(f"Prediction files are being aggregated in {output_file}.")
    with open(output_file, 'tw', encoding="utf-8") as outf:
        for rank in range(trainer.world_size):
            rank_file = os.path.join(cfg.output_path, f"predictions_{rank}.json")
            with open(rank_file, 'r', encoding="utf-8") as inpf:
                rank_lines = inpf.readlines()
            samples_num += len(rank_lines)
            outf.writelines(rank_lines)
    logging.info(
        f"Prediction is done for {samples_num} samples in total on all workers and results are aggregated in {output_file}."
    )
def main(cfg: ParallelTranscriptionConfig):
    """Transcribe ``cfg.predict_ds`` with an ASR model and log the aggregate WER/CER.

    The model is loaded from ``cfg.model``: a path ending in ``.nemo`` is
    restored, a path ending in ``.ckpt`` is loaded as a checkpoint, and any
    other value is treated as a pretrained model name. Prediction runs through
    ``trainer.predict`` with an ``ASRPredictionWriter`` callback that is given
    a per-rank output file (``predictions_<global_rank>.json``) under
    ``cfg.output_path``. Once all ranks are done, global rank zero merges the
    per-rank files into ``predictions_all.json`` and logs the error rate
    computed by ``word_error_rate`` over the ``pred_text`` / ``text`` fields.

    Args:
        cfg: Transcription configuration. This function reads ``model``,
            ``trainer``, ``predict_ds``, ``output_path``,
            ``return_predictions`` and ``use_cer``; ``cfg.predict_ds`` is
            modified in place.
    """
    if cfg.model.endswith(".nemo"):
        logging.info("Attempting to initialize from .nemo file")
        model = ASRModel.restore_from(restore_path=cfg.model,
                                      map_location="cpu")
    elif cfg.model.endswith(".ckpt"):
        logging.info("Attempting to initialize from .ckpt file")
        model = ASRModel.load_from_checkpoint(checkpoint_path=cfg.model,
                                              map_location="cpu")
    else:
        logging.info(
            "Attempting to initialize from a pretrained model as the model name does not have the extension of .nemo or .ckpt"
        )
        model = ASRModel.from_pretrained(model_name=cfg.model,
                                         map_location="cpu")

    trainer = ptl.Trainer(**cfg.trainer)

    cfg.predict_ds.return_sample_id = True
    cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds,
                                        train_ds=model.cfg.train_ds)
    data_loader = model._setup_dataloader_from_config(cfg.predict_ds)

    os.makedirs(cfg.output_path, exist_ok=True)
    # trainer.global_rank is not valid before predict() is called. Need this hack to find the correct global_rank.
    # num_devices is used rather than num_gpus so the per-node device count is
    # also correct on non-GPU accelerators, where num_gpus is 0.
    global_rank = trainer.node_rank * trainer.num_devices + int(
        os.environ.get("LOCAL_RANK", 0))
    output_file = os.path.join(cfg.output_path,
                               f"predictions_{global_rank}.json")
    predictor_writer = ASRPredictionWriter(dataset=data_loader.dataset,
                                           output_file=output_file)
    trainer.callbacks.extend([predictor_writer])

    # The return value is not needed here: results are read back from the
    # per-rank files handed to the writer callback.
    trainer.predict(model=model,
                    dataloaders=data_loader,
                    return_predictions=cfg.return_predictions)
    samples_num = predictor_writer.close_output_file()

    logging.info(
        f"Prediction on rank {global_rank} is done for {samples_num} samples and results are stored in {output_file}."
    )

    # Wait for every rank to finish writing before the per-rank files are merged.
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    samples_num = 0
    pred_text_list = []
    text_list = []
    if is_global_rank_zero():
        output_file = os.path.join(cfg.output_path, "predictions_all.json")
        logging.info(
            f"Prediction files are being aggregated in {output_file}.")
        # Explicit UTF-8 keeps transcript decoding independent of the platform locale.
        with open(output_file, 'w', encoding="utf-8") as outf:
            for rank in range(trainer.world_size):
                input_file = os.path.join(cfg.output_path,
                                          f"predictions_{rank}.json")
                with open(input_file, 'r', encoding="utf-8") as inpf:
                    # Stream line by line instead of loading the whole file.
                    for line in inpf:
                        if not line.strip():
                            # Tolerate blank lines (e.g. a trailing newline).
                            continue
                        item = json.loads(line)
                        pred_text_list.append(item["pred_text"])
                        text_list.append(item["text"])
                        outf.write(json.dumps(item) + "\n")
                        samples_num += 1
        wer_cer = word_error_rate(hypotheses=pred_text_list,
                                  references=text_list,
                                  use_cer=cfg.use_cer)
        logging.info(
            f"Prediction is done for {samples_num} samples in total on all workers and results are aggregated in {output_file}."
        )
        logging.info("{} for all predictions is {:.4f}.".format(
            "CER" if cfg.use_cer else "WER", wer_cer))