Example #1
    def _compute(self,
                 predictions=None,
                 references=None,
                 concatenate_texts=False):
        if type(predictions) != type(references):  # noqa: E721
            raise ValueError(
                f"`predictions` {predictions} are of type {type(predictions)}, "
                f"while `references` {references} are of type {type(references)}. "
                "Make sure `predictions` and `references` are of the same type.")

        inputs_are_lists = (isinstance(predictions, (list, tuple))
                            and isinstance(references, (list, tuple)))
        if inputs_are_lists and type(predictions[0]) != type(references[0]):  # noqa: E721
            raise ValueError(
                f"`predictions` {predictions} is a list/tuple of type {type(predictions[0])}, "
                f"while `references` {references} is a list/tuple of type {type(references[0])}. "
                "Make sure `predictions` and `references` are a list/tuple of the same type."
            )

        if concatenate_texts:
            # jiwer's compute_measures expects (truth, hypothesis) ordering.
            return compute_measures(references, predictions)["wer"]
        else:
            incorrect = 0
            total = 0
            for prediction, reference in zip(predictions, references):
                measures = compute_measures(reference, prediction)
                incorrect += (measures["substitutions"] + measures["deletions"]
                              + measures["insertions"])
                total += (measures["substitutions"] + measures["deletions"]
                          + measures["hits"])
            return incorrect / total
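A quick aside (not from the original project): the loop above relies on the identity WER = (S + D + I) / (S + D + H). A minimal sketch of that relationship, assuming jiwer 2.x where compute_measures returns these four counts, using made-up sentences:

import jiwer

# Hypothetical sentence pair, for illustration only.
measures = jiwer.compute_measures("the quick brown fox", "the quick brown dog")
s, d, i, h = (measures[k] for k in ("substitutions", "deletions", "insertions", "hits"))

# For a single pair this reproduces jiwer's own "wer" value; summing the four
# counts over a corpus before dividing gives the corpus-level WER.
assert abs((s + d + i) / (s + d + h) - measures["wer"]) < 1e-9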
Example #2
    def test_different_sentence_length(self):
        cases = [
            (
                ["hello", "this", "sentence", "is fractured"],
                ["this sentence"],
                _m(0.6, 0.6, 0.6),
            ),
            (
                "i am a short ground truth",
                "i am a considerably longer and very much incorrect hypothesis",
                _m(7 / 6, 0.7, 0.85),
            ),
        ]

        self._apply_test_on(cases)

        ground_truth = [
            "i like monthy python",
            "what do you mean african or european swallow",
        ]
        hypothesis = ["i like", "python", "what you mean", "or swallow"]
        x = jiwer.compute_measures(ground_truth, hypothesis)

        # is equivalent to

        ground_truth = (
            "i like monthy python what do you mean african or european swallow"
        )
        hypothesis = "i like python what you mean or swallow"
        y = jiwer.compute_measures(ground_truth, hypothesis)

        self.assertDictAlmostEqual(x, y, delta=1e-9)
Example #3
def calculate_measures(groundtruth, transcription):
    """Calculate character/word measures (hits, subs, inserts, deletes) for one given sentence"""
    groundtruth = normalize_sentence(groundtruth)
    transcription = normalize_sentence(transcription)

    #cer = ed.eval(transcription, groundtruth) / len(groundtruth)
    c_result = jiwer.compute_measures([c for c in groundtruth if c != " "],
                                      [c for c in transcription if c != " "])
    w_result = jiwer.compute_measures(groundtruth, transcription)

    return c_result, w_result, groundtruth, transcription
Example #4
def compute_wer(predictions=None, references=None, concatenate_texts=False):
    if concatenate_texts:
        # jiwer's compute_measures expects (truth, hypothesis) ordering.
        return compute_measures(references, predictions)["wer"]
    else:
        incorrect = 0
        total = 0
        for prediction, reference in zip(predictions, references):
            measures = compute_measures(reference, prediction)
            incorrect += (measures["substitutions"] + measures["deletions"]
                          + measures["insertions"])
            total += (measures["substitutions"] + measures["deletions"]
                      + measures["hits"])
        # Return the aggregated WER rather than the last pair's raw measures.
        return incorrect / total
Example #5
def analyze_da_realization(sentence_plan_path, realized_acts_path):
    with open(sentence_plan_path, 'r') as valid_freq_plan:
        sentence_plans = [line.strip() for line in valid_freq_plan]

    with open(realized_acts_path, 'r') as realized_acts:
        realized_acts = [line.strip() for line in realized_acts]

    print(jiwer.compute_measures(sentence_plans, realized_acts))

    all_plan_das = []
    all_real_das = []

    for plan, realized in zip(sentence_plans, realized_acts):
        plan_acts = plan.split(" ")
        real_das = realized.split(" ")

        len_diff = len(plan_acts) - len(real_das)

        if len_diff < 0:
            plan_acts = plan_acts + ["None"] * -len_diff
        elif len_diff > 0:
            real_das = real_das + ["None"] * len_diff

        all_plan_das += plan_acts
        all_real_das += real_das

    labels = list(set(all_plan_das))
    datoi = {label: i for i, label in enumerate(labels)}
    plan_da_idx = [datoi[label] for label in all_plan_das]
    real_da_idx = [datoi[label] for label in all_real_das]
    print(classification_report(plan_da_idx, real_da_idx, target_names=labels))
Example #6
def batched_asr_metrics(hypos, truths, metrics, skip_trans_examples=False):
    results = [
        compute_measures(truth, hypothesis)
        for truth, hypothesis in zip(truths, hypos)
        if truth.replace(' ', '') != "" and not (skip_trans_examples and "_trans" in hypothesis)
    ]
    return join_results(results, metrics)
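join_results is not shown in Example #6. A minimal sketch of one plausible implementation, assuming metrics is a list of measure names and that pooling raw counts before re-deriving each rate is the intended aggregation (the helper and its formulas are assumptions, not taken from the original project):

def join_results(results, metrics):
    # Hypothetical helper: pool the raw counts across utterances, then
    # recompute each requested rate from the pooled counts.
    s = sum(m["substitutions"] for m in results)
    d = sum(m["deletions"] for m in results)
    i = sum(m["insertions"] for m in results)
    h = sum(m["hits"] for m in results)
    pooled = {
        "wer": (s + d + i) / (s + d + h) if (s + d + h) else 0.0,
        "mer": (s + d + i) / (s + d + i + h) if (s + d + i + h) else 0.0,
    }
    return {name: pooled[name] for name in metrics if name in pooled}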
Example #7
def compute_cer(predictions, references, concatenate_texts=False):

    if concatenate_texts:
        return jiwer.wer(
            references,
            predictions,
            truth_transform=cer_transform,
            hypothesis_transform=cer_transform,
        )

    incorrect = 0
    total = 0
    for prediction, reference in zip(predictions, references):
        measures = jiwer.compute_measures(
            reference,
            prediction,
            truth_transform=cer_transform,
            hypothesis_transform=cer_transform,
        )
        incorrect += measures["substitutions"] + measures[
            "deletions"] + measures["insertions"]
        total += measures["substitutions"] + measures["deletions"] + measures[
            "hits"]

    return incorrect / total
Example #8
def compute_score(preds, labels, score_type="wer"):
    incorrect = 0
    total = 0
    for prediction, reference in zip(preds, labels):
        if score_type == "wer":
            measures = compute_measures(reference, prediction)
        elif score_type == "cer":
            measures = compute_measures(
                reference,
                prediction,
                truth_transform=cer_transform,
                hypothesis_transform=cer_transform)
        incorrect += (measures["substitutions"] + measures["deletions"]
                      + measures["insertions"])
        total += (measures["substitutions"] + measures["deletions"]
                  + measures["hits"])
    return incorrect / total
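Examples #7 and #8 above (and #14 and #18 below) use a cer_transform defined elsewhere in their source files. A minimal sketch of such a character-level transform, assuming jiwer >= 2.3 where ReduceToListOfListOfChars is available:

import jiwer.transforms as tr

# Assumed definition: normalise whitespace, then split each sentence into
# characters so compute_measures counts character-level errors (CER).
cer_transform = tr.Compose([
    tr.RemoveMultipleSpaces(),
    tr.Strip(),
    tr.ReduceToListOfListOfChars(),
])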
Example #9
    def update(self, predictions: Union[List[str], str], targets: Union[List[str], str], transform=None) -> None:
        """
        Add predictions and targets to the computation of the NLP metrics.

        Args:
            predictions (Union[str, List[str]]): hypothesis transcript(s).
            targets (Union[str, List[str]]): reference (ground-truth) transcript(s).
            transform: optional jiwer transformation applied to both truth and hypothesis.
        """
        if transform:
            mes = jiwer.compute_measures(truth=targets, hypothesis=predictions, truth_transform=transform, hypothesis_transform=transform)
        else:
            mes = jiwer.compute_measures(truth=targets, hypothesis=predictions)

        self.mer.track(mes["mer"])
        self.wer.track(mes["wer"])
        self.wil.track(mes["wil"])
Example #10
def calculate():

    global ground_truth
    ground_truth = request.form['gold_standard']

    global hypothesis
    hypothesis = request.form['translated_machine']

    if ((len(ground_truth) > 2) and (len(hypothesis) > 2)):
        ground_truth = re.sub(r"""[,.;@#?!&$]+\ *""",
                              " ",
                              ground_truth,
                              flags=re.VERBOSE)
        hypothesis = re.sub(r"""[,.;@#?!&$]+\ *""",
                            " ",
                            hypothesis,
                            flags=re.VERBOSE)
        measures = jiwer.compute_measures(
            ground_truth.lower().replace('<br/>', ''),
            hypothesis.lower().replace('<br/>', ''))
        wer = str("%.2f" % (measures['wer'] * 100)) + "%"
        mer = str("%.2f" % (measures['mer'] * 100)) + "%"
        wil = str("%.2f" % (measures['wil'] * 100)) + "%"
        wip = str("%.2f" % (measures['wip'] * 100)) + "%"

        r = re.sub(' +', ' ', ground_truth).split()
        h = re.sub(' +', ' ', hypothesis).split()

        # find out the manipulation steps
        d = editDistance(r, h)
        list_step = getStepList(r, h, d)

        count_s = [ele for ele in list_step if ele == 's']
        count_d = [ele for ele in list_step if ele == 'd']
        count_i = [ele for ele in list_step if ele == 'i']

        # print the result in aligned way
        result = float(d[len(r)][len(h)]) / len(r) * 100
        result = str("%.2f" % result) + "%"
        #alignedPrint(list_step, r, h, result)

        #print(ground_truth.lower().replace('<br/>', ''))
        #print(hypothesis.lower().replace('<br/>', ''))
        #print(wer, mer, wil, wip)
        return flask.render_template('index.html',
                                     gold_standard=ground_truth,
                                     translated_machine=hypothesis,
                                     wer=wer,
                                     mer=mer,
                                     wil=wil,
                                     wip=wip,
                                     cs_wer=result,
                                     sub=len(count_s),
                                     dele=len(count_d),
                                     ins=len(count_i))

    else:
        return flask.render_template('index.html')
Example #11
    def _apply_test_on(self, cases):
        for gt, h, correct_measures in cases:
            measures = jiwer.compute_measures(truth=gt, hypothesis=h)
            # Remove entries we are not testing against
            for k in ["hits", "substitutions", "deletions", "insertions"]:
                measures.pop(k)
            self.assertDictAlmostEqual(measures, correct_measures, delta=1e-16)
Example #12
    def hsdi(truth, hypothesis, transformation):
        from jiwer import compute_measures

        keep = ["hits", "substitutions", "deletions", "insertions"]
        out = compute_measures(
            truth=truth,
            hypothesis=hypothesis,
            truth_transform=transformation,
            hypothesis_transform=transformation,
        ).items()
        return {k: v for k, v in out if k in keep}
Example #13
def analyse_wer(corrected_sentences, predicted_sentences):
    word_error_rate = list()
    match_error_rate = list()
    word_info_lost = list()
    for corrected, predicted in zip(corrected_sentences, predicted_sentences):
        # jiwer expects (truth, hypothesis): the corrected sentence is the reference.
        all_measures = jiwer.compute_measures(corrected, predicted)
        word_error_rate.append(all_measures["wer"])
        match_error_rate.append(all_measures["mer"])
        word_info_lost.append(all_measures["wil"])
    return (word_error_rate, match_error_rate, word_info_lost)
Example #14
def wer_and_cer(preds, labels, concatenate_texts, config_name):
    try:
        from jiwer import compute_measures
    except ImportError:
        raise ValueError(
            f"jiwer has to be installed in order to apply the wer metric for {config_name}. "
            "You can install it via `pip install jiwer`.")

    if concatenate_texts:
        wer = compute_measures(labels, preds)["wer"]

        cer = compute_measures(labels,
                               preds,
                               truth_transform=cer_transform,
                               hypothesis_transform=cer_transform)["wer"]
        return {"wer": wer, "cer": cer}
    else:

        def compute_score(preds, labels, score_type="wer"):
            incorrect = 0
            total = 0
            for prediction, reference in zip(preds, labels):
                if score_type == "wer":
                    measures = compute_measures(reference, prediction)
                elif score_type == "cer":
                    measures = compute_measures(
                        reference,
                        prediction,
                        truth_transform=cer_transform,
                        hypothesis_transform=cer_transform)
                incorrect += measures["substitutions"] + measures[
                    "deletions"] + measures["insertions"]
                total += measures["substitutions"] + measures[
                    "deletions"] + measures["hits"]
            return incorrect / total

        return {
            "wer": compute_score(preds, labels, "wer"),
            "cer": compute_score(preds, labels, "cer")
        }
Example #15
    def validation_step(self, batch, batch_nb):
        xs, ys, xlen, ylen = batch
        y, nll = self.model.greedy_decode(xs, xlen)

        hypothesis = self.tokenizer.decode_plus(y)
        ground_truth = self.tokenizer.decode_plus(ys.cpu().numpy())
        measures = jiwer.compute_measures(ground_truth, hypothesis)

        return {
            'val_loss': nll.mean().item(),
            'wer': measures['wer'],
            'ground_truth': ground_truth[0],
            'hypothesis': hypothesis[0]
        }
Example #16
def metric_for_text(text_groundtruth, text_estimated):
    # transformation = jiwer.Compose([
    #     jiwer.ToLowerCase(),
    #     jiwer.RemoveMultipleSpaces(),
    #     jiwer.RemoveWhiteSpace(replace_by_space=False),
    #     jiwer.SentencesToListOfWords(word_delimiter=" ")
    # ])
    # measures = jiwer.compute_measures(text_groundtruth, text_estimated, truth_transform=transformation,
    #                                   hypothesis_transform=transformation)
    measures = jiwer.compute_measures(text_groundtruth, text_estimated)
    wer = measures['wer']
    mer = measures['mer']
    wil = measures['wil']

    return [wer, mer, wil]
Example #17
    def _apply_test_on(self, cases):
        for gt, h, correct_measures in cases:
            measures = jiwer.compute_measures(
                truth=gt,
                hypothesis=h,
                truth_transform=jiwer.transformations.wer_contiguous,
                hypothesis_transform=jiwer.transformations.wer_contiguous,
            )
            # Remove entries we are not testing against
            for k in ["hits", "substitutions", "deletions", "insertions"]:
                measures.pop(k)
            assertDictAlmostEqual(self,
                                  measures,
                                  correct_measures,
                                  delta=1e-16)
Example #18
    def _compute(self, predictions, references):
        incorrect = 0
        total = 0
        for prediction, reference in zip(predictions, references):
            measures = jiwer.compute_measures(
                reference,
                prediction,
                truth_transform=_transform,
                hypothesis_transform=_transform,
            )
            incorrect += (
                measures["substitutions"]
                + measures["deletions"]
                + measures["insertions"]
            )
            total += (
                measures["substitutions"] + measures["deletions"] + measures["hits"]
            )

        return incorrect / total
Example #19
    def call_fairseq(self, df):
        data_dir, manifest_name = os.path.split(self.manifest)
        manifest_name = manifest_name.replace('.tsv', '')

        inference_command = f"fairseq-generate {data_dir} --config-yaml config.yaml --gen-subset {manifest_name} --task speech_to_text --path {self.model_path} --max-tokens 50000 --beam 5 --scoring wer --results-path {self.transcripts_dir}"
        os.system(inference_command)

        # Currently, fairseq only reports the total score for the entire dataset inference.
        # Therefore, we must take the result transcripts and recompute the score for every utterance.
        # First, parse the obtained results.
        transcripts_path = os.path.join(self.transcripts_dir,
                                        f"generate-{manifest_name}.txt")
        eval_dict = {}
        with open(transcripts_path, 'r') as f:
            curr_sample_id = ""
            for line in f.readlines():
                if line.startswith('T-'):
                    curr_sample_id = line.split('\t')[0].replace('T-', '')
                    ref = line.split('\t')[1].strip()
                elif line.startswith('D-'):
                    assert (
                        curr_sample_id == line.split('\t')[0].replace(
                            'D-', '')
                    ), "Reference transcript is not followed by a detokenized hypothesis!"
                    hyp = line.split('\t')[2].strip()
                    eval_dict[curr_sample_id] = (ref, hyp)

        # Now, compute and append the score for every utterance at the dataframe.
        df['wer'] = np.nan
        for sample_id, (ref, hyp) in eval_dict.items():
            measures = jiwer.compute_measures(ref, hyp)
            wer = measures['wer'] * 100.0

            assert (
                ref == df.loc[int(sample_id), 'tgt_text']
            ), "The reference text indicated by the sample ID in the transcripts file does not match with the one stored in the dataset!"
            df.at[int(sample_id), 'wer'] = wer

        return df
Example #20
def analyze():
    try:
        req_data = request.get_json()

        compose_rule_set = []
        if req_data.get('to_lower_case', False):
            compose_rule_set.append(jiwer.ToLowerCase())
        if req_data.get('strip_punctuation', False):
            compose_rule_set.append(jiwer.RemovePunctuation())
        if req_data.get('strip_words', False):
            compose_rule_set.append(jiwer.Strip())
        if req_data.get('strip_multi_space', False):
            compose_rule_set.append(jiwer.RemoveMultipleSpaces())
        word_excepts = req_data.get('t_words', '')
        if word_excepts != '':
            words = [a.strip() for a in word_excepts.split(",")]
            compose_rule_set.append(jiwer.RemoveSpecificWords(words))

        compose_rule_set.append(
            jiwer.RemoveWhiteSpace(
                replace_by_space=req_data.get('replace_whitespace', False)))

        transformation = jiwer.Compose(compose_rule_set)

        measures = jiwer.compute_measures(req_data.get('s_truth', ""),
                                          req_data.get('s_hypo', ""),
                                          truth_transform=transformation,
                                          hypothesis_transform=transformation)

        return jsonify({
            "wer": measures['wer'],
            "mer": measures['mer'],
            "wil": measures['wil']
        })
    except Exception:
        return jsonify("API endpoint Error")
Example #21
def _compute_wer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]):
    return compute_measures(target, preds)["wer"]
Example #22
                jiwer.RemoveWhiteSpace(replace_by_space=True),
                jiwer.SentencesToListOfWords(),
                jiwer.RemovePunctuation(),
                jiwer.RemoveEmptyStrings(),
                jiwer.SubstituteRegexes({r"ё": r"е"})
            ])
            gt = transformation([ground_truth])
            hp = transformation([hypothesis])

            gt, hp = replace_pairs(gt, hp)
            hp, gt = replace_pairs(hp, gt)

            wer(gt, hp)

            r = jiwer.compute_measures(
                gt,
                hp
            )
            print(f"\nWER:{r['wer'] * 100:.3f}\t\tS:{r['S']} D:{r['D']} H:{r['H']} I:{r['I']}\n")

            S += r["S"]
            D += r["D"]
            I += r["I"]
            H += r["H"]

        insertions = kd.query_text(0, 100000000)

        print(f"Лишние слова: {kd.query_text(0, 100000000).split(' ')}, I:{len(insertions.split(' '))}")
        I += len(insertions.split(' '))

        stop_time = time.time()
Example #23
    def call_huggingface(self, df):
        assert self.model_url != '', "Error! A model URL is needed for HuggingFace scoring, but --asr_download_model is empty"
        if self.tokenizer_url == '':
            print(
                f"Setting empty --tokenizer_url field identically to --asr_download_model: {self.model_url}"
            )
            self.tokenizer_url = self.model_url

        if self.scoring_sorting == 'ascending':
            df = df.sort_values(by=['n_frames']).reset_index(drop=True)
        elif self.scoring_sorting == 'descending':
            df = df.sort_values(by=['n_frames'],
                                ascending=False).reset_index(drop=True)
        elif self.scoring_sorting == '':
            pass
        else:
            raise NotImplementedError

        print(f"Preparing dataloader for manifest {self.manifest}...")
        dataset = AudioDataset(df)
        dataloader = DataLoader(dataset,
                                batch_size=self.batch_size,
                                collate_fn=dataset.collater,
                                num_workers=self.num_workers,
                                pin_memory=True)

        if self.hf_username == 'facebook':
            print(f"Downloading tokenizer: {self.tokenizer_url}")
            tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
                self.tokenizer_url)

            print(f"Downloading model: {self.model_url}")
            model = Wav2Vec2ForCTC.from_pretrained(self.model_url)
        elif self.hf_username == 'speechbrain':
            if torch.cuda.is_available():
                run_opts = {"device": "cuda"}
            else:
                run_opts = {"device": "cpu"}
            print(f"Downloading model: {self.model_url}")
            model = EncoderDecoderASR.from_hparams(source=self.model_url,
                                                   run_opts=run_opts,
                                                   savedir=os.path.join(
                                                       'pretrained_models',
                                                       self.hf_modelname))
        else:
            raise NotImplementedError

        model.eval()

        print("Scoring dataset...")
        df['wer'] = np.nan

        for batch in tqdm(dataloader):
            indexes, waveforms, transcripts, wav_lens = batch

            if self.hf_username == 'facebook':
                output_logits = model(waveforms.squeeze()).logits
                predicted_ids = torch.argmax(output_logits, dim=-1)
                pred_transcripts = tokenizer.batch_decode(predicted_ids)
            elif self.hf_username == 'speechbrain':
                waveforms = waveforms.squeeze()
                #waveforms = model.audio_normalizer(waveforms, self.sampling_rate)
                pred_transcripts = model.transcribe_batch(waveforms,
                                                          wav_lens)[0]

            for index, ref in enumerate(transcripts):
                sample_id = indexes[index]
                ref = transcripts[index]
                pred = pred_transcripts[index]
                measures = jiwer.compute_measures(ref, pred)
                wer = measures['wer'] * 100.0
                assert (
                    ref == df.loc[int(sample_id), 'tgt_text']
                ), "The reference text indicated by the sample ID in the transcripts file does not match with the one stored in the dataset!"
                df.at[int(sample_id), 'wer'] = wer

        return df
Example #24
            all_expected.append(expected)
            all_actual.append(actual)

            if args.verbose:
                message = textwrap.dedent("""\
                {filename} WER={measures[wer]:.2f} MER={measures[mer]:.2f}
                EXPECTED: {expected!r}
                ACTUAL:   {actual!r}
                """)
                print(message.format(
                    filename=filename,
                    expected=transform(expected),
                    actual=transform(actual),
                    measures=jiwer.compute_measures(
                        hypothesis=actual,
                        truth=expected,
                        truth_transform=transform,
                        hypothesis_transform=transform),
                ))

    message = textwrap.dedent("""\
    Overall:
    WER:  {measures[wer]:.2f}
    MER:  {measures[mer]:.2f}
    WIL:  {measures[wil]:.2f}
    WIP:  {measures[wip]:.2f}
    """)
    print(message.format(measures=jiwer.compute_measures(
        hypothesis=all_actual,
        truth=all_expected,
        truth_transform=transform,
        hypothesis_transform=transform,
    )))
Example #25
def evaluate(model, dataset_paths):
    """ Evaluate model with given datasets. """
    datasets = []
    wer = []
    time_per_file = []

    for j in range(len(dataset_paths)):
        datasets.append(dataset_paths[j].rsplit('/', 1)[1])

        # Used when calculating time_per_file for each dataset
        time_per_file.append(0)
        files_total = 0

        s = 0  # num of substitutions
        d = 0  # num of deletions
        i = 0  # num of insertions
        c = 0  # num of correct words

        for folder in os.listdir(dataset_paths[j]):
            folder_path = os.path.join(dataset_paths[j], folder)
            for sub_folder in os.listdir(folder_path):
                sub_folder_path = os.path.join(folder_path, sub_folder)

                # .trans.txt contains the names and ground truths for audio files in sub_folder
                trans_path = os.path.join(
                    sub_folder_path,
                    str(folder + '-' + sub_folder + '.trans.txt'))

                with open(trans_path) as file:
                    lines = file.readlines()

                audio_files, ground_truths = zip(
                    *([x.split(' ', 1) for x in lines]))
                audio_file_paths = [
                    os.path.join(sub_folder_path, x + '.flac')
                    for x in audio_files
                ]

                # Pre-process ground truth values
                ground_truths = [x.lower().strip() for x in ground_truths]

                print('Transcribing audio files found under', sub_folder_path)
                for x in range(len(audio_file_paths)):
                    print('Working ...')
                    start_time = time.time()
                    transcription = speech_to_text(
                        model, audio_file_paths[x]).lower().strip()
                    time_per_file[j] += (time.time() - start_time)
                    files_total += 1

                    # Computing wer from all dataset audio transcriptions causes a MemoryError so it is done in parts
                    measures = jiwer.compute_measures(ground_truths[x],
                                                      transcription)

                    s += measures['substitutions']
                    d += measures['deletions']
                    i += measures['insertions']
                    c += measures['hits']

        time_per_file[j] = time_per_file[j] / files_total
        wer.append(float(s + d + i) / float(s + d + c))

    return [datasets, wer, time_per_file]
Example #26
def load_data(data_filename,
              disable_caching=False,
              estimate_audio=False,
              vocab=None):

    if vocab is not None:
        # load external vocab
        vocabulary_ext = {}
        with open(vocab, 'r') as f:
            for line in f:
                if '\t' in line:
                    # parse word from TSV file
                    word = line.split('\t')[0]
                else:
                    # assume each line contains just a single word
                    word = line.strip()
                vocabulary_ext[word] = 1

    if not disable_caching:
        pickle_filename = data_filename.split('.json')[0]
        json_mtime = datetime.datetime.fromtimestamp(
            os.path.getmtime(data_filename))
        timestamp = json_mtime.strftime('%Y%m%d_%H%M')
        pickle_filename += '_' + timestamp + '.pkl'
        if os.path.exists(pickle_filename):
            with open(pickle_filename, 'rb') as f:
                data, wer, cer, wmr, mwa, num_hours, vocabulary_data, alphabet, metrics_available = pickle.load(
                    f)
            if vocab is not None:
                for item in vocabulary_data:
                    item['OOV'] = item['word'] not in vocabulary_ext
            if estimate_audio:
                for item in data:
                    signal, sr = librosa.load(path=item['audio_filepath'],
                                              sr=None)
                    bw = eval_bandwidth(signal, sr)
                    item['freq_bandwidth'] = int(bw)
                    item['level_db'] = 20 * np.log10(np.max(np.abs(signal)))
            with open(pickle_filename, 'wb') as f:
                pickle.dump(
                    [
                        data, wer, cer, wmr, mwa, num_hours, vocabulary_data,
                        alphabet, metrics_available
                    ],
                    f,
                    pickle.HIGHEST_PROTOCOL,
                )
            return data, wer, cer, wmr, mwa, num_hours, vocabulary_data, alphabet, metrics_available

    data = []
    wer_dist = 0.0
    wer_count = 0
    cer_dist = 0.0
    cer_count = 0
    wmr_count = 0
    wer = 0
    cer = 0
    wmr = 0
    mwa = 0
    num_hours = 0
    vocabulary = defaultdict(lambda: 0)
    alphabet = set()
    match_vocab = defaultdict(lambda: 0)

    sm = difflib.SequenceMatcher()
    metrics_available = False
    with open(data_filename, 'r', encoding='utf8') as f:
        for line in tqdm.tqdm(f):
            item = json.loads(line)
            if not isinstance(item['text'], str):
                item['text'] = ''
            num_chars = len(item['text'])
            orig = item['text'].split()
            num_words = len(orig)
            for word in orig:
                vocabulary[word] += 1
            for char in item['text']:
                alphabet.add(char)
            num_hours += item['duration']

            if 'pred_text' in item:
                metrics_available = True
                pred = item['pred_text'].split()
                measures = jiwer.compute_measures(item['text'],
                                                  item['pred_text'])
                word_dist = measures['substitutions'] + measures[
                    'insertions'] + measures['deletions']
                char_dist = editdistance.eval(item['text'], item['pred_text'])
                wer_dist += word_dist
                cer_dist += char_dist
                wer_count += num_words
                cer_count += num_chars

                sm.set_seqs(orig, pred)
                for m in sm.get_matching_blocks():
                    for word_idx in range(m[0], m[0] + m[2]):
                        match_vocab[orig[word_idx]] += 1
                wmr_count += measures['hits']

            data.append({
                'audio_filepath': item['audio_filepath'],
                'duration': round(item['duration'], 2),
                'num_words': num_words,
                'num_chars': num_chars,
                'word_rate': round(num_words / item['duration'], 2),
                'char_rate': round(num_chars / item['duration'], 2),
                'text': item['text'],
            })
            if metrics_available:
                data[-1]['pred_text'] = item['pred_text']
                if num_words == 0:
                    num_words = 1e-9
                if num_chars == 0:
                    num_chars = 1e-9
                data[-1]['WER'] = round(word_dist / num_words * 100.0, 2)
                data[-1]['CER'] = round(char_dist / num_chars * 100.0, 2)
                data[-1]['WMR'] = round(measures['hits'] / num_words * 100.0,
                                        2)
                data[-1]['I'] = measures['insertions']
                data[-1]['D'] = measures['deletions']
                data[-1][
                    'D-I'] = measures['deletions'] - measures['insertions']

            if estimate_audio:
                signal, sr = librosa.load(path=item['audio_filepath'], sr=None)
                bw = eval_bandwidth(signal, sr)
                item['freq_bandwidth'] = int(bw)
                item['level_db'] = 20 * np.log10(np.max(np.abs(signal)))
            for k in item:
                if k not in data[-1]:
                    data[-1][k] = item[k]

    vocabulary_data = [{
        'word': word,
        'count': vocabulary[word]
    } for word in vocabulary]
    if vocab is not None:
        for item in vocabulary_data:
            item['OOV'] = item['word'] not in vocabulary_ext

    if metrics_available:
        wer = wer_dist / wer_count * 100.0
        cer = cer_dist / cer_count * 100.0
        wmr = wmr_count / wer_count * 100.0

        acc_sum = 0
        for item in vocabulary_data:
            w = item['word']
            word_accuracy = match_vocab[w] / vocabulary[w] * 100.0
            acc_sum += word_accuracy
            item['accuracy'] = round(word_accuracy, 1)
        mwa = acc_sum / len(vocabulary_data)

    num_hours /= 3600.0

    if not disable_caching:
        with open(pickle_filename, 'wb') as f:
            pickle.dump(
                [
                    data, wer, cer, wmr, mwa, num_hours, vocabulary_data,
                    alphabet, metrics_available
                ],
                f,
                pickle.HIGHEST_PROTOCOL,
            )

    return data, wer, cer, wmr, mwa, num_hours, vocabulary_data, alphabet, metrics_available
Example #27
    def _apply_test_on(self, cases):
        for gt, h, correct_measures in cases:
            measures = jiwer.compute_measures(truth=gt, hypothesis=h)
            self.assertDictAlmostEqual(measures, correct_measures, delta=1e-16)