Example #1 (score: 0)
def process_evaluation_epoch(
    global_vars, metrics=('loss', 'bpc', 'ppl'), calc_wer=False, mode='eval', tag='none',
):
    """Aggregate per-batch evaluation tensors into scalar epoch metrics.

    Args:
        global_vars: dict mapping each metric name to a list of 0-d tensors
            (one per batch); when ``calc_wer`` is set it must also hold
            'transcript_texts' and 'prediction_texts' (lists of lists of str).
        metrics: metric names to average over the epoch.
        calc_wer: also compute word error rate from the collected texts.
        mode: prefix for the emitted metric keys (e.g. 'eval').
        tag: free-form suffix; whitespace is collapsed to underscores.

    Returns:
        dict of ``f'metric/{mode}_{name}_{tag}'`` -> float.
    """
    tag = '_'.join(tag.lower().strip().split())
    return_dict = {
        f'metric/{mode}_{metric}_{tag}': torch.mean(torch.stack(global_vars[metric])).item()
        for metric in metrics
    }

    # Derive perplexity from bits-per-character: 2 ** (bpc * mean word length).
    # NOTE(review): this overwrites any 'ppl' averaged above; kept for parity
    # with the original behavior.  TODO: Delete?
    bpc = return_dict[f'metric/{mode}_bpc_{tag}']
    return_dict[f'metric/{mode}_ppl_{tag}'] = 2 ** (bpc * ENG_MWN)

    if calc_wer:
        transcript_texts = list(chain(*global_vars['transcript_texts']))
        prediction_texts = list(chain(*global_vars['prediction_texts']))

        # No placeholders, so a plain string literal suffices (was an f-string).
        logging.info('Ten examples (transcripts and predictions)')
        logging.info(transcript_texts[:10])
        logging.info(prediction_texts[:10])

        wer = word_error_rate(hypotheses=prediction_texts, references=transcript_texts)
        return_dict[f'metric/{mode}_wer_{tag}'] = wer

    logging.info(pformat(return_dict))

    return return_dict
Example #2 (score: 0)
 def data_fn(
     transcript,
     audio_dur,
     wav_data,
     caller_name,
     aud_seg,
     fname,
     audio_path,
     num_datapoints,
     rel_data_path,
 ):
     """Build one manifest entry for an audio datapoint.

     Transcribes the segment with the pretrained ASR, scores it against the
     ground-truth transcript, renders the waveform plot if missing, and
     returns the manifest record as a dict.
     """
     pretrained_result = transcriber_pretrained(aud_seg.raw_data)
     # Explicit keywords: the original positional call passed the ground-truth
     # transcript as `hypotheses` and the ASR output as `references` — the
     # reverse of the convention used elsewhere in this file.
     pretrained_wer = word_error_rate(
         hypotheses=[pretrained_result], references=[transcript]
     )
     wav_plot_path = (
         dataset_dir / Path("wav_plots") / Path(fname).with_suffix(".png")
     )
     # Plots are cached on disk; only render when absent.
     if not wav_plot_path.exists():
         plot_seg(wav_plot_path, audio_path)
     return {
         "audio_filepath": str(rel_data_path),
         "duration": round(audio_dur, 1),
         "text": transcript,
         "real_idx": num_datapoints,
         "audio_path": audio_path,
         "spoken": transcript,
         "caller": caller_name,
         "utterance_id": fname,
         "pretrained_asr": pretrained_result,
         "pretrained_wer": pretrained_wer,
         "plot_path": str(wav_plot_path),
     }
Example #3 (score: 0)
def preprocess_datapoint(idx, rel_root, sample, use_domain_asr,
                         annotation_only, enable_plots):
    """Enrich one manifest ``sample`` with ASR transcriptions, WER and plots.

    Args:
        idx: position of the sample in the dataset (stored as 'real_idx').
        rel_root: directory that 'audio_filepath' entries are relative to.
        sample: manifest dict; must contain 'text' and 'audio_filepath'.
        use_domain_asr: also run the domain (speller) ASR on port 8045.
        annotation_only: skip all ASR transcription when True.
        enable_plots: render (and cache) a waveform plot for the sample.

    Returns:
        The enriched dict, or None if processing failed (error is printed).
    """
    from pydub import AudioSegment
    from nemo.collections.asr.metrics import word_error_rate
    from jasper.client import transcribe_gen

    try:
        res = dict(sample)
        res["real_idx"] = idx
        audio_path = rel_root / Path(sample["audio_filepath"])
        res["audio_path"] = str(audio_path)
        if use_domain_asr:
            res["spoken"] = alnum_to_asr_tokens(res["text"])
        else:
            res["spoken"] = res["text"]
        res["utterance_id"] = audio_path.stem
        if not annotation_only:
            transcriber_pretrained = transcribe_gen(asr_port=8044)

            # Normalize to mono / 16-bit / 24 kHz before transcription.
            aud_seg = (
                AudioSegment.from_file_using_temporary_files(audio_path).
                set_channels(1).set_sample_width(2).set_frame_rate(24000))
            res["pretrained_asr"] = transcriber_pretrained(aud_seg.raw_data)
            # Keywords make the argument roles explicit; the original
            # positional call had hypotheses/references reversed.
            res["pretrained_wer"] = word_error_rate(
                hypotheses=[res["pretrained_asr"]], references=[res["text"]])
            if use_domain_asr:
                transcriber_speller = transcribe_gen(asr_port=8045)
                res["domain_asr"] = transcriber_speller(aud_seg.raw_data)
                # BUG FIX: domain WER previously scored the *pretrained*
                # transcription instead of the domain ASR output.
                res["domain_wer"] = word_error_rate(
                    hypotheses=[res["domain_asr"]], references=[res["spoken"]])
        if enable_plots:
            wav_plot_path = (rel_root / Path("wav_plots") /
                             Path(audio_path.name).with_suffix(".png"))
            if not wav_plot_path.exists():
                plot_seg(wav_plot_path, audio_path)
            res["plot_path"] = str(wav_plot_path)
        return res
    # Exception (not BaseException): let KeyboardInterrupt/SystemExit
    # propagate instead of being swallowed by the best-effort handler.
    except Exception as e:
        print(f'failed on {idx}: {sample["audio_filepath"]} with {e}')
Example #4 (score: 0)
def eval_epochs_done_callback_wer(global_vars):
    """End-of-epoch callback: report mean eval loss and corpus WER.

    Flattens the per-batch 'ref' and 'sys' sentence lists accumulated in
    ``global_vars``, prints three random reference/translation pairs plus the
    summary numbers, then clears the accumulators for the next epoch.

    Returns:
        dict with keys 'eval_loss' and 'eval_wer'.
    """
    eval_loss = np.mean(global_vars["eval_loss"])
    # Flatten the lists of per-batch sentence lists.
    ref = [sent for batch in global_vars["ref"] for sent in batch]
    sys = [sent for batch in global_vars["sys"] for sent in batch]
    # Keywords make the argument roles explicit; the original positional call
    # passed references in the hypotheses slot (reversed vs. this file's
    # keyword convention).
    eval_wer = word_error_rate(hypotheses=sys, references=ref)

    # Show a small random sample of outputs for qualitative inspection.
    for _ in range(3):
        sent_id = np.random.randint(len(sys))
        print("Ground truth: {0}\n".format(ref[sent_id]))
        print("Translation:  {0}\n".format(sys[sent_id]))

    print("Validation loss: {0}".format(np.round(eval_loss, 3)))
    print("Validation WER: {0}".format(eval_wer))
    # Reset accumulators so the next epoch starts clean.
    global_vars["eval_loss"] = []
    global_vars["ref"] = []
    global_vars["sys"] = []

    return {"eval_loss": eval_loss, "eval_wer": eval_wer}