Example #1
File: cut_audio.py Project: manneh/NeMo
def add_transcript_to_manifest(manifest_original: str, manifest_updated: str,
                               asr_model: nemo_asr.models.EncDecCTCModel,
                               batch_size: int) -> None:
    """
    Adds transcripts generated by the asr_model to each entry of manifest_original and writes the result to manifest_updated.

    Args:
        manifest_original: path to the manifest
        manifest_updated: path to the updated manifest with transcript included
        asr_model: CTC-based ASR model, for example, QuartzNet15x5Base-En
        batch_size: Batch size for asr_model inference
    """
    transcripts = get_transcript(manifest_original, asr_model, batch_size)
    with open(manifest_original, 'r', encoding='utf8') as f:
        with open(manifest_updated, 'w', encoding='utf8') as f_updated:
            for i, line in enumerate(f):
                info = json.loads(line)
                info['pred_text'] = transcripts[i].strip()
                info['WER'] = round(
                    word_error_rate([info['pred_text']], [info['text']]) * 100,
                    2)
                info['CER'] = round(
                    word_error_rate([info['pred_text']], [info['text']],
                                    use_cer=True) * 100, 2)
                json.dump(info, f_updated, ensure_ascii=False)
                f_updated.write('\n')
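A minimal usage sketch (the model name is a real NGC checkpoint, but the manifest paths here are hypothetical):

import nemo.collections.asr as nemo_asr

asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name="QuartzNet15x5Base-En")
add_transcript_to_manifest("data/manifest.json", "data/manifest_with_preds.json",
                           asr_model, batch_size=32)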
Example #2
    def test_wer_metric_randomized(self, test_wer_bpe):
        """This test relies on correctness of word_error_rate function."""
        def __random_string(length):
            return ''.join(
                random.choice(''.join(self.vocabulary)) for _ in range(length))

        if test_wer_bpe:
            wer = WERBPE(deepcopy(self.char_tokenizer),
                         batch_dim_index=0,
                         use_cer=False,
                         ctc_decode=True)
        else:
            wer = WER(vocabulary=self.vocabulary,
                      batch_dim_index=0,
                      use_cer=False,
                      ctc_decode=True)

        for test_id in range(256):
            n1 = random.randint(1, 512)
            n2 = random.randint(1, 512)
            s1 = __random_string(n1)
            s2 = __random_string(n2)
            # skip empty strings as reference
            if s2.strip():
                assert (abs(
                    self.get_wer(wer,
                                 prediction=s1,
                                 reference=s2,
                                 use_tokenizer=test_wer_bpe) -
                    word_error_rate(hypotheses=[s1], references=[s2])) < 1e-6)
Example #3
def get_wer_feat(mfst, asr, tokens_per_chunk, delay, model_stride_in_secs,
                 batch_size):
    hyps = []
    refs = []
    audio_filepaths = []

    with open(mfst, "r") as mfst_f:
        print("Parsing manifest files...")
        for line in mfst_f:
            row = json.loads(line.strip())
            audio_filepaths.append(row['audio_filepath'])
            refs.append(row['text'])

    with torch.inference_mode():
        with torch.cuda.amp.autocast():
            batch = []
            asr.sample_offset = 0
            for idx in tqdm.tqdm(range(len(audio_filepaths)),
                                 desc='Sample:',
                                 total=len(audio_filepaths)):
                batch.append((audio_filepaths[idx], refs[idx]))

                if len(batch) == batch_size:
                    audio_files = [sample[0] for sample in batch]

                    asr.reset()
                    asr.read_audio_file(audio_files, delay,
                                        model_stride_in_secs)
                    hyp_list = asr.transcribe(tokens_per_chunk, delay)
                    hyps.extend(hyp_list)

                    batch.clear()
                    asr.sample_offset += batch_size

            if len(batch) > 0:
                asr.batch_size = len(batch)
                asr.frame_bufferer.batch_size = len(batch)
                asr.reset()

                audio_files = [sample[0] for sample in batch]
                asr.read_audio_file(audio_files, delay, model_stride_in_secs)
                hyp_list = asr.transcribe(tokens_per_chunk, delay)
                hyps.extend(hyp_list)

                asr.sample_offset += len(batch)
                batch.clear()  # clear after updating the offset, or len(batch) is already 0

    if os.environ.get('DEBUG', '0') in ('1', 'y', 't'):
        for hyp, ref in zip(hyps, refs):
            print("hyp:", hyp)
            print("ref:", ref)

    wer = word_error_rate(hypotheses=hyps, references=refs)
    return hyps, refs, wer
Example #4
    def test_wer_function(self):
        assert word_error_rate(hypotheses=['cat'], references=['cot']) == 1.0
        assert word_error_rate(hypotheses=['GPU'], references=['G P U']) == 1.0
        assert word_error_rate(hypotheses=['G P U'], references=['GPU']) == 3.0
        assert word_error_rate(hypotheses=['ducati motorcycle'], references=['motorcycle']) == 1.0
        assert word_error_rate(hypotheses=['ducati motorcycle'], references=['ducuti motorcycle']) == 0.5
        assert word_error_rate(hypotheses=['a B c'], references=['a b c']) == 1.0 / 3.0
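These values follow directly from the definition WER = (substitutions + insertions + deletions) / number of reference words, which is why WER can exceed 1.0: 'G P U' against the single reference word 'GPU' costs one substitution plus two insertions, giving 3 / 1 = 3.0. A minimal, NeMo-independent sketch of the same word-level edit distance (not the library's implementation):

def simple_wer(hypothesis: str, reference: str) -> float:
    # Word-level Levenshtein distance divided by the number of reference words.
    hyp, ref = hypothesis.split(), reference.split()
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,          # deletion
                          d[i][j - 1] + 1,          # insertion
                          d[i - 1][j - 1] + cost)   # substitution
    return d[-1][-1] / len(ref)

assert simple_wer('G P U', 'GPU') == 3.0
assert simple_wer('a B c', 'a b c') == 1.0 / 3.0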
Example #5
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model", type=str, default="QuartzNet15x5Base-En", required=True, help="Pass: '******'",
    )
    parser.add_argument("--dataset", type=str, required=True, help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test")
    parser.add_argument(
        "--normalize_text", default=True, type=bool, help="Normalize transcripts or not. Set to False for non-English."
    )
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)
    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
        }
    )
    if can_gpu:
        asr_model = asr_model.cuda()
    asr_model.eval()
    labels_map = dict([(i, asr_model.decoder.vocabulary[i]) for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary)
    hypotheses = []
    references = []
    for test_batch in asr_model.test_dataloader():
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1]
            )
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        for batch_ind in range(greedy_predictions.shape[0]):
            reference = ''.join([labels_map[c] for c in test_batch[2][batch_ind].cpu().detach().numpy()])
            references.append(reference)
        del test_batch
    wer_value = word_error_rate(hypotheses=hypotheses, references=references)
    if wer_value > args.wer_tolerance:
        raise ValueError(f"Got WER of {wer_value}. It was higher than {args.wer_tolerance}")
    logging.info(f'Got WER of {wer_value}. Tolerance was {args.wer_tolerance}')
Example #6
def evaluate(asr_model, asr_onnx, labels_map, wer, qat):
    # Eval the model
    hypotheses = []
    references = []
    stream = cuda.Stream()
    vocabulary_size = len(labels_map) + 1
    engine_file_path = build_trt_engine(asr_model, asr_onnx, qat)
    with open(engine_file_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
        trt_engine = runtime.deserialize_cuda_engine(f.read())
        trt_ctx = trt_engine.create_execution_context()

        profile_shape = trt_engine.get_profile_shape(profile_index=0,
                                                     binding=0)
        print("profile shape min:{}, opt:{}, max:{}".format(
            profile_shape[0], profile_shape[1], profile_shape[2]))
        max_input_shape = profile_shape[2]
        input_nbytes = trt.volume(max_input_shape) * trt.float32.itemsize
        d_input = cuda.mem_alloc(input_nbytes)
        max_output_shape = [
            max_input_shape[0], vocabulary_size, (max_input_shape[-1] + 1) // 2
        ]
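        # The (max_input_shape[-1] + 1) // 2 above sizes the output buffer for an
        # encoder that halves the time dimension (true for QuartzNet's stride-2
        # first block); adjust if the model uses a different time reduction.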
        output_nbytes = trt.volume(max_output_shape) * trt.float32.itemsize
        d_output = cuda.mem_alloc(output_nbytes)

        for test_batch in asr_model.test_dataloader():
            if can_gpu:
                test_batch = [x.cuda() for x in test_batch]
            processed_signal, processed_signal_length = asr_model.preprocessor(
                input_signal=test_batch[0], length=test_batch[1])

            greedy_predictions = trt_inference(
                stream,
                trt_ctx,
                d_input,
                d_output,
                input_signal=processed_signal,
                input_signal_length=processed_signal_length,
            )
            hypotheses += wer.ctc_decoder_predictions_tensor(
                greedy_predictions)
            for batch_ind in range(greedy_predictions.shape[0]):
                seq_len = test_batch[3][batch_ind].cpu().detach().numpy()
                seq_ids = test_batch[2][batch_ind].cpu().detach().numpy()
                reference = ''.join(
                    [labels_map[c] for c in seq_ids[0:seq_len]])
                references.append(reference)
            del test_batch
        wer_value = word_error_rate(hypotheses=hypotheses,
                                    references=references,
                                    use_cer=wer.use_cer)

    return wer_value
Example #7
    def test_wer_metric_randomized(self):
        """This test relies on correctness of word_error_rate function."""

        def __randomString(N):
            return ''.join(random.choice(''.join(self.vocabulary)) for i in range(N))

        wer = WER(vocabulary=self.vocabulary, batch_dim_index=0, use_cer=False, ctc_decode=True)

        for test_id in range(256):
            n1 = random.randint(1, 512)
            n2 = random.randint(1, 512)
            s1 = __randomString(n1)
            s2 = __randomString(n2)
            # Floating-point math doesn't seem to be an issue here. Leaving as ==
            assert self.get_wer(wer, prediction=s1, reference=s2) == word_error_rate(hypotheses=[s1], references=[s2])
Example #8
def evaluate(asr_model, labels_map, wer):
    # Eval the model
    hypotheses = []
    references = []
    for test_batch in asr_model.test_dataloader():
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1]
            )
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        for batch_ind in range(greedy_predictions.shape[0]):
            reference = ''.join([labels_map[c] for c in test_batch[2][batch_ind].cpu().detach().numpy()])
            references.append(reference)
        del test_batch
    wer_value = word_error_rate(hypotheses=hypotheses, references=references)

    return wer_value
Example #9
def calculate_cer(normalized_texts: List[str], transcript: str, remove_punct=False) -> List[Tuple[str, float]]:
    """
    Calculates character error rate (CER)

    Args:
        normalized_texts: normalized text options
        transcript: ASR model output
        remove_punct: whether to strip punctuation from the options before computing CER

    Returns: normalized options with corresponding CER
    """
    normalized_options = []
    for text in normalized_texts:
        text_clean = text.replace('-', ' ').lower()
        if remove_punct:
            for punct in "!?:;,.-()*+-/<=>@^_":
                text_clean = text_clean.replace(punct, "")
        cer = round(word_error_rate([transcript], [text_clean], use_cer=True) * 100, 2)
        normalized_options.append((text, cer))
    return normalized_options
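A hypothetical call, picking the normalization option closest to the ASR output (lowest CER):

options = ["twenty-five dollars", "25 dollars", "$25"]
scored = calculate_cer(options, transcript="twenty five dollars", remove_punct=True)
best_option, best_cer = min(scored, key=lambda pair: pair[1])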
Example #10
    def test_rnnt_wer_metric_randomized(self, test_wer_bpe):
        """This test relies on correctness of word_error_rate function."""
        def __random_string(length):
            return ''.join(
                random.choice(''.join(self.vocabulary)) for _ in range(length))

        for test_id in range(256):
            n1 = random.randint(1, 512)
            n2 = random.randint(1, 512)
            s1 = __random_string(n1)
            s2 = __random_string(n2)
            # skip empty strings as reference
            if s2.strip():
                assert (abs(
                    self.get_wer_rnnt(prediction=s1,
                                      reference=s2,
                                      batch_dim_index=0,
                                      test_wer_bpe=test_wer_bpe) -
                    word_error_rate(hypotheses=[s1], references=[s2])) < 1e-6)
Example #11
    def get_wder_dict_values(self, asr_eval_dict, wder_dict, count_dict, align_error_list):
        """
        Calculate the total error rates for WDER, WER and alignment error.
        """
        if '-' in asr_eval_dict['references_list'] or None in asr_eval_dict['references_list']:
            wer = -1
        else:
            wer = word_error_rate(
                hypotheses=asr_eval_dict['hypotheses_list'], references=asr_eval_dict['references_list']
            )

        wder_dict['total_WER'] = wer
        wder_dict['total_wder_rttm'] = 1 - (
            count_dict['grand_total_correct_word_count'] / count_dict['grand_total_pred_word_count']
        )

        if all(self.ctm_exists.values()):
            wder_dict['total_wder_ctm_ref_trans'] = (
                count_dict['total_ctm_wder_count'] / count_dict['grand_total_ctm_word_count']
                if count_dict['grand_total_ctm_word_count'] > 0
                else -1
            )
            wder_dict['total_wder_ctm_pred_asr'] = (
                count_dict['total_ctm_wder_count'] / count_dict['grand_total_pred_word_count']
                if count_dict['grand_total_pred_word_count'] > 0
                else -1
            )
            wder_dict['total_diar_trans_acc'] = (
                count_dict['total_asr_and_spk_correct_words'] / count_dict['grand_total_ctm_word_count']
                if count_dict['grand_total_ctm_word_count'] > 0
                else -1
            )
            wder_dict['total_alignment_error_mean'] = (
                np.mean(self.align_error_list).round(4) if self.align_error_list != [] else -1
            )
            wder_dict['total_alignment_error_std'] = (
                np.std(self.align_error_list).round(4) if self.align_error_list != [] else -1
            )
        return wder_dict
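As a worked example of the WDER formula above: if diarization assigns the correct speaker to 90 of the 100 predicted words, total_wder_rttm = 1 - (90 / 100) = 0.1; the -1 values act as sentinels for missing references or zero word counts.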
Example #12
def batch_inference(args: argparse.Namespace):

    torch.set_grad_enabled(False)

    if args.asr_model.endswith(".nemo"):
        print(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        print(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)

    manifest = prepare_manifest(args.corpora_dir, args.limit)
    asr_model.setup_test_data(
        test_data_config={
            "sample_rate": 16000,
            "manifest_filepath": manifest,
            "labels": asr_model.decoder.vocabulary,
            "batch_size": args.batch_size,
            "normalize_transcripts": args.normalize_text,
        })

    refs_hyps = list(tqdm(generate_ref_hyps(asr_model, args.search,
                                            args.arpa)))
    references, hypotheses = [list(k) for k in zip(*refs_hyps)]

    os.makedirs(args.results_dir, exist_ok=True)
    data_io.write_lines(f"{args.results_dir}/refs.txt.gz", references)
    data_io.write_lines(f"{args.results_dir}/hyps.txt.gz", hypotheses)

    wer_value = word_error_rate(hypotheses=hypotheses, references=references)
    sys.stdout.flush()
    stats = {
        "wer": wer_value,
        "args": args.__dict__,
    }
    data_io.write_json(f"{args.results_dir}/stats.txt", stats)
    print(f"Got WER of {wer_value}")
    return stats
Example #13
def main():
    args = parse_arguments()

    # Instantiate pytorch model
    nemo_model = args.nemo_model
    nemo_model = ASRModel.restore_from(nemo_model,
                                       map_location='cpu')  # type: ASRModel
    nemo_model.freeze()

    if torch.cuda.is_available():
        nemo_model = nemo_model.to('cuda')

    export_model_if_required(args, nemo_model)

    # Instantiate RNNT Decoding loop
    encoder_model = args.onnx_encoder
    decoder_model = args.onnx_decoder
    max_symbols_per_step = args.max_symbold_per_step
    decoding = ONNXGreedyBatchedRNNTInfer(encoder_model, decoder_model,
                                          max_symbols_per_step)

    audio_filepath = resolve_audio_filepaths(args)

    # Evaluate Pytorch Model (CPU/GPU)
    actual_transcripts = nemo_model.transcribe(audio_filepath,
                                               batch_size=args.batch_size)[0]

    # Evaluate ONNX model (on CPU)
    with tempfile.TemporaryDirectory() as tmpdir:
        with open(os.path.join(tmpdir, 'manifest.json'), 'w') as fp:
            for audio_file in audio_filepath:
                entry = {
                    'audio_filepath': audio_file,
                    'duration': 100000,
                    'text': 'nothing'
                }
                fp.write(json.dumps(entry) + '\n')

        config = {
            'paths2audio_files': audio_filepath,
            'batch_size': args.batch_size,
            'temp_dir': tmpdir
        }

        # Push nemo model to CPU
        nemo_model = nemo_model.to('cpu')
        nemo_model.preprocessor.featurizer.dither = 0.0
        nemo_model.preprocessor.featurizer.pad_to = 0

        temporary_datalayer = nemo_model._setup_transcribe_dataloader(config)

        all_hypothesis = []
        for test_batch in tqdm(temporary_datalayer, desc="ONNX Transcribing"):
            input_signal, input_signal_length = test_batch[0], test_batch[1]

            # Acoustic features
            processed_audio, processed_audio_len = nemo_model.preprocessor(
                input_signal=input_signal, length=input_signal_length)
            # RNNT Decoding loop
            hypotheses = decoding(audio_signal=processed_audio,
                                  length=processed_audio_len)

            # Process hypothesis (map char/subword token ids to text)
            hypotheses = nemo_model.decoding.decode_hypothesis(
                hypotheses)  # type: List[str]

            # Extract text from the hypothesis
            texts = [h.text for h in hypotheses]

            all_hypothesis += texts
            del processed_audio, processed_audio_len
            del test_batch

    if args.log:
        for pt_transcript, onnx_transcript in zip(actual_transcripts,
                                                  all_hypothesis):
            print(f"Pytorch Transcripts : {pt_transcript}")
            print(f"ONNX Transcripts    : {onnx_transcript}")
        print()

    # Measure the character error rate between ONNX and PyTorch transcripts
    pt_onnx_cer = word_error_rate(all_hypothesis,
                                  actual_transcripts,
                                  use_cer=True)
    assert pt_onnx_cer < args.threshold, "Threshold violation !"

    print("Character error rate between Pytorch and ONNX :", pt_onnx_cer)
Example #14
def main(cfg: EvaluationConfig):
    torch.set_grad_enabled(False)

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if cfg.audio_dir is not None:
        raise RuntimeError(
            "Evaluation script requires ground truth labels to be passed via a manifest file. "
            "If manifest file is available, submit it via `dataset_manifest` argument."
        )

    if not os.path.exists(cfg.dataset_manifest):
        raise FileNotFoundError(
            f"The dataset manifest file could not be found at path : {cfg.dataset_manifest}"
        )

    if not cfg.only_score_manifest:
        # Transcribe speech into an output directory
        transcription_cfg = transcribe_speech.main(
            cfg)  # type: EvaluationConfig

        # Release GPU memory if it was used during transcription
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logging.info(
            "Finished transcribing speech dataset. Computing ASR metrics...")

    else:
        cfg.output_filename = cfg.dataset_manifest
        transcription_cfg = cfg

    ground_truth_text = []
    predicted_text = []
    invalid_manifest = False
    with open(transcription_cfg.output_filename, 'r') as f:
        for line in f:
            data = json.loads(line)

            if 'pred_text' not in data:
                invalid_manifest = True
                break

            ground_truth_text.append(data['text'])
            predicted_text.append(data['pred_text'])

    # Test for invalid manifest supplied
    if invalid_manifest:
        raise ValueError(
            f"Invalid manifest provided: {transcription_cfg.output_filename} does not "
            f"contain value for `pred_text`.")

    # Compute the WER
    metric_name = 'CER' if cfg.use_cer else 'WER'
    metric_value = word_error_rate(hypotheses=predicted_text,
                                   references=ground_truth_text,
                                   use_cer=cfg.use_cer)

    if cfg.tolerance is not None:
        if metric_value > cfg.tolerance:
            raise ValueError(
                f"Got {metric_name} of {metric_value}, which was higher than tolerance={cfg.tolerance}"
            )

        logging.info(
            f'Got {metric_name} of {metric_value}. Tolerance was {cfg.tolerance}'
        )
    else:
        logging.info(f'Got {metric_name} of {metric_value}')

    # Inject the metric name and score into the config, and return the entire config
    with open_dict(cfg):
        cfg.metric_name = metric_name
        cfg.metric_value = metric_value

    return cfg
Example #15
def ASR_Grade(dataset, id, key):
    try:
        from torch.cuda.amp import autocast
    except ImportError:
        from contextlib import contextmanager

        @contextmanager
        def autocast(enabled=None):
            yield

    can_gpu = torch.cuda.is_available()

    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default=model_Selected,
        required=True,
        help=f'Pass: {model_Selected}',
    )
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--wer_tolerance",
                        type=float,
                        default=1.0,
                        help="used by test")
    parser.add_argument(
        "--normalize_text",
        default=False,  # False <- we're using phonetic references
        type=bool,
        help="Normalize transcripts or not. Set to False for non-English.",
    )
    args = parser.parse_args(
        ["--dataset", dataset, "--asr_model", model_Selected])
    torch.set_grad_enabled(False)

    # Instantiate Jasper/QuartzNet models with the EncDecCTCModel class.
    asr_model = EncDecCTCModel.restore_from(model_Path)

    asr_model.setup_test_data(
        test_data_config={
            "sample_rate": 16000,
            "manifest_filepath": args.dataset,
            "labels": asr_model.decoder.vocabulary,
            "batch_size": args.batch_size,
            "normalize_transcripts": args.normalize_text,
        })
    if can_gpu:  # noqa
        asr_model = asr_model.cuda()
    asr_model.eval()
    labels_map = dict([(i, asr_model.decoder.vocabulary[i])
                       for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary)
    hypotheses = []
    references = []
    for test_batch in asr_model.test_dataloader():
        if can_gpu:
            test_batch = [x.cuda() for x in test_batch]
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1])
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)  # accumulate; '=' would drop earlier batches
        for batch_ind in range(greedy_predictions.shape[0]):
            reference = key
            #reference = "".join([labels_map[c] for c in test_batch[2][batch_ind].cpu().detach().numpy()])  #debug
            print(reference)  #debug
            references.append(reference)
        del test_batch
    wer_value = word_error_rate(hypotheses=hypotheses,
                                references=references)  #cer=True

    REC = '.'
    REF = '.'
    for h, r in zip(hypotheses, references):
        print("Recognized:\t{}\nReference:\t{}\n".format(h, r))
        REC = h
        REF = r
    logging.info(f"Got PER of {wer_value}. Tolerance was {args.wer_tolerance}")

    # Score calculation, phoneme conversion:
    # divide wer_value by wer_tolerance to get the error ratio (and round it),
    # then multiply by 100 to express it as a percentage.
    # Since this gives the "% wrong", subtract from 100 to get the "% correct",
    # which yields a positive grade to return to the user.
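    # e.g., wer_value = 0.25 with wer_tolerance = 1.0 gives
    # score = 100.00 - round(0.25, 4) * 100 = 75.0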
    score = 100.00 - (round((wer_value / args.wer_tolerance), 4) * 100)
    if score < 0.0:
        score = 0.0
    print(score)

    # Result file creation, to be accessed by JS via 'app.py'
    with open(datasetPath + id + '_graded.txt', 'w') as results_file:
        results_file.write(REC + '\n' + REF + '\n' + str(score))
    return score
Example #16
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        choices=[
            x.pretrained_model_name
            for x in EncDecCTCModel.list_available_models()
        ],
    )
    parser.add_argument(
        "--tts_model_spec",
        type=str,
        default="Tacotron2-22050Hz",
        choices=[
            x.pretrained_model_name
            for x in SpectrogramGenerator.list_available_models()
        ],
    )
    parser.add_argument(
        "--tts_model_vocoder",
        type=str,
        default="WaveGlow-22050Hz",
        choices=[
            x.pretrained_model_name for x in Vocoder.list_available_models()
        ],
    )
    parser.add_argument("--wer_tolerance",
                        type=float,
                        default=1.0,
                        help="used by test")
    parser.add_argument("--trim", action="store_true")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    torch.set_grad_enabled(False)

    if args.debug:
        logging.set_verbosity(logging.DEBUG)

    logging.info(f"Using NGC cloud ASR model {args.asr_model}")
    asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)
    logging.info(
        f"Using NGC cloud TTS Spectrogram Generator model {args.tts_model_spec}"
    )
    tts_model_spec = SpectrogramGenerator.from_pretrained(
        model_name=args.tts_model_spec)
    logging.info(f"Using NGC cloud TTS Vocoder model {args.tts_model_vocoder}")
    tts_model_vocoder = Vocoder.from_pretrained(
        model_name=args.tts_model_vocoder)
    models = [asr_model, tts_model_spec, tts_model_vocoder]

    if torch.cuda.is_available():
        for i, m in enumerate(models):
            models[i] = m.cuda()
    for m in models:
        m.eval()

    asr_model, tts_model_spec, tts_model_vocoder = models

    parser = parsers.make_parser(
        labels=asr_model.decoder.vocabulary,
        name="en",
        unk_id=-1,
        blank_id=-1,
        do_normalize=True,
    )
    labels_map = dict([(i, asr_model.decoder.vocabulary[i])
                       for i in range(len(asr_model.decoder.vocabulary))])

    tts_input = []
    asr_references = []
    longest_tts_input = 0
    for test_str in LIST_OF_TEST_STRINGS:
        tts_parsed_input = tts_model_spec.parse(test_str)
        if len(tts_parsed_input[0]) > longest_tts_input:
            longest_tts_input = len(tts_parsed_input[0])
        tts_input.append(tts_parsed_input.squeeze())

        asr_parsed = parser(test_str)
        asr_parsed = ''.join([labels_map[c] for c in asr_parsed])
        asr_references.append(asr_parsed)

    # Pad TTS Inputs
    for i, text in enumerate(tts_input):
        pad = (0, longest_tts_input - len(text))
        tts_input[i] = torch.nn.functional.pad(text, pad, value=68)

    logging.debug(tts_input)

    # Do TTS
    tts_input = torch.stack(tts_input)
    if torch.cuda.is_available():
        tts_input = tts_input.cuda()
    specs = tts_model_spec.generate_spectrogram(tokens=tts_input)
    audio = []
    step = ceil(len(specs) / 4)
    for i in range(4):
        audio.append(
            tts_model_vocoder.convert_spectrogram_to_audio(
                spec=specs[i * step:i * step + step]))

    audio = [item for sublist in audio for item in sublist]
    audio_file_paths = []
    # Save audio
    logging.debug(f"args.trim: {args.trim}")
    for i, aud in enumerate(audio):
        aud = aud.cpu().numpy()
        if args.trim:
            aud = librosa.effects.trim(aud, top_db=40)[0]
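        # NOTE: librosa.output.write_wav was removed in librosa 0.8.0; on newer
        # versions, soundfile.write(f"{i}.wav", aud, 22050) is the usual replacement.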
        librosa.output.write_wav(f"{i}.wav", aud, sr=22050)
        audio_file_paths.append(str(Path(f"{i}.wav")))

    # Do ASR
    hypotheses = asr_model.transcribe(audio_file_paths)
    for i, _ in enumerate(hypotheses):
        logging.debug(f"{i}")
        logging.debug(f"ref:'{asr_references[i]}'")
        logging.debug(f"hyp:'{hypotheses[i]}'")
    wer_value = word_error_rate(hypotheses=hypotheses,
                                references=asr_references)
    if wer_value > args.wer_tolerance:
        raise ValueError(
            f"Got WER of {wer_value}. It was higher than {args.wer_tolerance}")
    logging.info(f'Got WER of {wer_value}. Tolerance was {args.wer_tolerance}')
Example #17
def main(cfg: ParallelTranscriptionConfig):
    if cfg.model.endswith(".nemo"):
        logging.info("Attempting to initialize from .nemo file")
        model = ASRModel.restore_from(restore_path=cfg.model,
                                      map_location="cpu")
    elif cfg.model.endswith(".ckpt"):
        logging.info("Attempting to initialize from .ckpt file")
        model = ASRModel.load_from_checkpoint(checkpoint_path=cfg.model,
                                              map_location="cpu")
    else:
        logging.info(
            "Attempting to initialize from a pretrained model as the model name does not have the extension of .nemo or .ckpt"
        )
        model = ASRModel.from_pretrained(model_name=cfg.model,
                                         map_location="cpu")

    trainer = ptl.Trainer(**cfg.trainer)

    cfg.predict_ds.return_sample_id = True
    cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds,
                                        train_ds=model.cfg.train_ds)
    data_loader = model._setup_dataloader_from_config(cfg.predict_ds)

    os.makedirs(cfg.output_path, exist_ok=True)
    # trainer.global_rank is not valid before predict() is called. Need this hack to find the correct global_rank.
    global_rank = trainer.node_rank * trainer.num_gpus + int(
        os.environ.get("LOCAL_RANK", 0))
    output_file = os.path.join(cfg.output_path,
                               f"predictions_{global_rank}.json")
    predictor_writer = ASRPredictionWriter(dataset=data_loader.dataset,
                                           output_file=output_file)
    trainer.callbacks.extend([predictor_writer])

    predictions = trainer.predict(model=model,
                                  dataloaders=data_loader,
                                  return_predictions=cfg.return_predictions)
    if predictions is not None:
        predictions = list(itertools.chain.from_iterable(predictions))
    samples_num = predictor_writer.close_output_file()

    logging.info(
        f"Prediction on rank {global_rank} is done for {samples_num} samples and results are stored in {output_file}."
    )

    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    samples_num = 0
    pred_text_list = []
    text_list = []
    if is_global_rank_zero():
        output_file = os.path.join(cfg.output_path, "predictions_all.json")
        logging.info(
            f"Prediction files are being aggregated in {output_file}.")
        with open(output_file, 'w') as outf:
            for rank in range(trainer.world_size):
                input_file = os.path.join(cfg.output_path,
                                          f"predictions_{rank}.json")
                with open(input_file, 'r') as inpf:
                    lines = inpf.readlines()
                    for line in lines:
                        item = json.loads(line)
                        pred_text_list.append(item["pred_text"])
                        text_list.append(item["text"])
                        outf.write(json.dumps(item) + "\n")
                        samples_num += 1
        wer_cer = word_error_rate(hypotheses=pred_text_list,
                                  references=text_list,
                                  use_cer=cfg.use_cer)
        logging.info(
            f"Prediction is done for {samples_num} samples in total on all workers and results are aggregated in {output_file}."
        )
        logging.info("{} for all predictions is {:.4f}.".format(
            "CER" if cfg.use_cer else "WER", wer_cer))
Example #18
def main(args):
    torch.set_grad_enabled(False)
    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = nemo_asr.models.EncDecCTCModelBPE.restore_from(
            restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(
            model_name=args.asr_model)

    cfg = copy.deepcopy(asr_model._cfg)
    OmegaConf.set_struct(cfg.preprocessor, False)

    # some changes for streaming scenario
    cfg.preprocessor.dither = 0.0
    cfg.preprocessor.pad_to = 0

    if cfg.preprocessor.normalize != "per_feature":
        logging.error(
            "Only EncDecRNNTBPEModel models trained with per_feature normalization are supported currently"
        )

    device = args.device
    if device is None:
        if torch.cuda.is_available():
            device = 'cuda'
        else:
            device = 'cpu'

    logging.info(f"Inference will be done on device : {device}")

    # Disable config overwriting
    OmegaConf.set_struct(cfg.preprocessor, True)
    asr_model.freeze()
    asr_model = asr_model.to(device)

    # Change Decoding Config
    decoding_cfg = asr_model.cfg.decoding
    with open_dict(decoding_cfg):
        if args.stateful_decoding:
            decoding_cfg.strategy = "greedy"
        else:
            decoding_cfg.strategy = "greedy_batch"

        decoding_cfg.preserve_alignments = True  # required to compute the middle token for transducers.
        decoding_cfg.fused_batch_size = -1  # temporarily stop fused batch during inference.

    asr_model.change_decoding_strategy(decoding_cfg)

    feature_stride = cfg.preprocessor['window_stride']
    model_stride_in_secs = feature_stride * args.model_stride
    total_buffer = args.total_buffer_in_secs

    chunk_len = float(args.chunk_len_in_secs)

    tokens_per_chunk = math.ceil(chunk_len / model_stride_in_secs)
    mid_delay = math.ceil(
        (chunk_len + (total_buffer - chunk_len) / 2) / model_stride_in_secs)
    print("Tokens per chunk :", tokens_per_chunk, "Min Delay :", mid_delay)

    if args.merge_algo == 'middle':
        frame_asr = BatchedFrameASRRNNT(
            asr_model=asr_model,
            frame_len=chunk_len,
            total_buffer=args.total_buffer_in_secs,
            batch_size=args.batch_size,
            max_steps_per_timestep=args.max_steps_per_timestep,
            stateful_decoding=args.stateful_decoding,
        )

    elif args.merge_algo == 'lcs':
        frame_asr = LongestCommonSubsequenceBatchedFrameASRRNNT(
            asr_model=asr_model,
            frame_len=chunk_len,
            total_buffer=args.total_buffer_in_secs,
            batch_size=args.batch_size,
            max_steps_per_timestep=args.max_steps_per_timestep,
            stateful_decoding=args.stateful_decoding,
            alignment_basepath=args.lcs_alignment_dir,
        )
        # Set the LCS algorithm delay.
        frame_asr.lcs_delay = math.floor(
            (total_buffer - chunk_len) / model_stride_in_secs)

    else:
        raise ValueError(
            "Invalid choice of merge algorithm for transducer buffered inference."
        )

    hyps, refs, wer = get_wer_feat(
        mfst=args.test_manifest,
        asr=frame_asr,
        tokens_per_chunk=tokens_per_chunk,
        delay=mid_delay,
        model_stride_in_secs=model_stride_in_secs,
        batch_size=args.batch_size,
    )
    logging.info(
        f"WER is {round(wer, 4)} when decoded with a delay of {round(mid_delay*model_stride_in_secs, 2)}s"
    )

    if args.output_path is not None:

        fname = (os.path.splitext(os.path.basename(args.asr_model))[0] + "_" +
                 os.path.splitext(os.path.basename(args.test_manifest))[0] +
                 "_" + str(args.chunk_len_in_secs) + "_" +
                 str(int(total_buffer * 1000)) + "_" + args.merge_algo +
                 ".json")

        hyp_json = os.path.join(args.output_path, fname)
        os.makedirs(args.output_path, exist_ok=True)
        with open(hyp_json, "w") as out_f:
            for i, hyp in enumerate(hyps):
                record = {
                    "pred_text": hyp,
                    "text": refs[i],
                    "wer": round(
                        word_error_rate(hypotheses=[hyp], references=[refs[i]])
                        * 100, 2),
                }
                out_f.write(json.dumps(record) + '\n')
Example #19
def main():
    parser = ArgumentParser()
    """Training arguments"""
    parser.add_argument("--asr_model",
                        type=str,
                        default="QuartzNet15x5Base-En",
                        required=True,
                        help="Pass: '******'")
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="path to evaluation data")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument(
        "--normalize_text",
        default=True,
        type=bool,
        help="Normalize transcripts or not. Set to False for non-English.")
    parser.add_argument("--shuffle",
                        action='store_true',
                        help="Shuffle test data.")
    """Calibration arguments"""
    parser.add_argument("--load",
                        type=str,
                        default=None,
                        help="load path for the synthetic data")
    parser.add_argument(
        "--percentile",
        type=float,
        default=None,
        help="Max/min percentile for outlier handling. e.g., 99.9")
    """Quantization arguments"""
    parser.add_argument("--weight_bit",
                        type=int,
                        default=8,
                        help="quantization bit for weights")
    parser.add_argument("--act_bit",
                        type=int,
                        default=8,
                        help="quantization bit for activations")
    parser.add_argument("--dynamic",
                        action='store_true',
                        help="Dynamic quantization mode.")
    parser.add_argument("--no_quant",
                        action='store_true',
                        help="No quantization mode.")
    """Debugging arguments"""
    parser.add_argument("--eval_early_stop",
                        type=int,
                        default=None,
                        help="early stop for debugging")
    parser.add_argument("--calib_early_stop",
                        type=int,
                        default=None,
                        help="early stop calibration")

    args = parser.parse_args()

    torch.set_grad_enabled(False)

    if args.asr_model.endswith('.nemo'):
        logging.info(f"Using local ASR model from {args.asr_model}")
        asr_model = EncDecCTCModel.restore_from(restore_path=args.asr_model)
    else:
        logging.info(f"Using NGC cloud ASR model {args.asr_model}")
        asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)

    asr_model.setup_test_data(
        test_data_config={
            'sample_rate': 16000,
            'manifest_filepath': args.dataset,
            'labels': asr_model.decoder.vocabulary,
            'batch_size': args.batch_size,
            'normalize_transcripts': args.normalize_text,
            'shuffle': args.shuffle,
        })

    if args.load is not None:
        print('Data loaded from %s' % args.load)
        with open(args.load, 'rb') as f:
            distilled_data = pickle.load(f)
        synthetic_batch_size, _, synthetic_seqlen = distilled_data[0].shape
    else:
        assert args.dynamic, \
                "synthetic data must be loaded unless running with the dynamic quantization mode"

    ############################## Calibration #####################################

    torch.set_grad_enabled(False)  # disable backward graph generation
    asr_model.eval()  # evaluation mode
    asr_model.set_quant_bit(args.weight_bit, mode='weight')
    asr_model.set_quant_bit(args.act_bit, mode='act')

    # set percentile
    if args.percentile is not None:
        qm.set_percentile(asr_model, args.percentile)

    if args.no_quant:
        asr_model.set_quant_mode('none')
    else:
        asr_model.encoder.bn_folding()  # BN folding

    # if not dynamic quantization, calibrate min/max/range for the activations using synthetic data
    # if dynamic, we can skip calibration
    if not args.dynamic:
        print('Calibrating...')
        qm.calibrate(asr_model)
        length = torch.tensor([synthetic_seqlen] * synthetic_batch_size).cuda()
        for batch_idx, inputs in enumerate(distilled_data):
            if args.calib_early_stop is not None and batch_idx == args.calib_early_stop:
                break
            inputs = inputs.cuda()
            encoded, encoded_len, encoded_scaling_factor = asr_model.encoder(
                audio_signal=inputs, length=length)
            log_probs = asr_model.decoder(
                encoder_output=encoded,
                encoder_output_scaling_factor=encoded_scaling_factor)

    ############################## Evaluation  #####################################

    print('Evaluating...')
    qm.evaluate(asr_model)

    qm.set_dynamic(
        asr_model,
        args.dynamic)  # if dynamic quantization, this will be enabled
    labels_map = dict([(i, asr_model.decoder.vocabulary[i])
                       for i in range(len(asr_model.decoder.vocabulary))])
    wer = WER(vocabulary=asr_model.decoder.vocabulary)
    hypotheses = []
    references = []
    progress_bar = tqdm(asr_model.test_dataloader())

    for i, test_batch in enumerate(progress_bar):
        if i == args.eval_early_stop:
            break
        test_batch = [x.cuda().float() for x in test_batch]
        with autocast():
            log_probs, encoded_len, greedy_predictions = asr_model(
                input_signal=test_batch[0], input_signal_length=test_batch[1])
        hypotheses += wer.ctc_decoder_predictions_tensor(greedy_predictions)
        for batch_ind in range(greedy_predictions.shape[0]):
            reference = ''.join([
                labels_map[c]
                for c in test_batch[2][batch_ind].cpu().detach().numpy()
            ])
            references.append(reference)
        del test_batch
    wer_value = word_error_rate(hypotheses=hypotheses, references=references)
    print('WER:', wer_value)
Example #20
            print("\ttags=", " ".join(predicted_tags[i]))
            print("\tpred=", predictions[i])
            print("\tsemiotic=", predicted_semiotic[i])
            print("\tref=", references[i][-1])  # last reference is actual reference
            sentences_with_errors_on_digits += 1
        elif ok_all:
            correct_sentences_disregarding_space += 1
        elif args.print_other_errors:
            print("other error:")
            print("\tinput=", " ".join(inputs[i]))
            print("\ttags=", " ".join(predicted_tags[i]))
            print("\tpred=", predictions[i])
            print("\tsemiotic=", predicted_semiotic[i])
            print("\tref=", references[i][-1])  # last reference is actual reference

    wer = word_error_rate(refs_for_wer, preds_for_wer)
    print("WER: ", wer)
    print(
        "Sentence accuracy: ",
        correct_sentences_disregarding_space / (len(inputs) - len(skip_ids)),
        correct_sentences_disregarding_space,
    )
    print(
        "digit errors: ",
        sentences_with_errors_on_digits / (len(inputs) - len(skip_ids)),
        sentences_with_errors_on_digits,
    )
    print(
        "other errors: ",
        (len(inputs) - len(skip_ids) - correct_sentences_disregarding_space - sentences_with_errors_on_digits)
        / (len(inputs) - len(skip_ids)),
    )