Beispiel #1
0
    def add_evaluation(self, utt_id, ref, pred, bpe_symbol=None):
        """Score one utterance and add it to the running statistics.

        Nothing on ``self`` is modified unless the whole evaluation succeeds:
        the duplicate-id check runs before any work is done, and the
        counters and the aligned result are only updated at the very end.

        Args:
            utt_id (str): id of the utterance; must not have been added
                before.
            ref (str): reference as whitespace-separated tokens.
            pred (str): prediction as whitespace-separated tokens.
            bpe_symbol (str, optional): forwarded unchanged to
                ``self.dictionary.tokens_to_sentence``.

        Raises:
            TypeError: if ``utt_id``, ``ref`` or ``pred`` is not a ``str``.
            AssertionError: if ``utt_id`` is already present in
                ``self.aligned_results``, or if the dictionary's
                ``non_lang_syms`` attribute is neither ``None`` nor a list.
        """
        if not isinstance(utt_id, str):
            raise TypeError('utt_id must be a string(got {})'.format(
                type(utt_id)))
        if not isinstance(ref, str):
            raise TypeError('ref must be a string (got {})'.format(type(ref)))
        if not isinstance(pred, str):
            raise TypeError('pred must be a string(got {})'.format(type(pred)))

        # Reject duplicates up front: checking only after the counters have
        # been incremented would leave a rejected utterance counted anyway.
        assert utt_id not in self.aligned_results, \
            'Duplicated utterance id detected: {}'.format(utt_id)

        # filter out any non_lang_syms from ref and pred
        non_lang_syms = getattr(self.dictionary, 'non_lang_syms', None)
        assert non_lang_syms is None or isinstance(non_lang_syms, list)
        ref_tokens, pred_tokens = ref.strip().split(), pred.strip().split()
        if non_lang_syms:
            excluded = set(non_lang_syms)  # O(1) membership per token
            ref_tokens = [x for x in ref_tokens if x not in excluded]
            pred_tokens = [x for x in pred_tokens if x not in excluded]
            # only the filtered case rewrites ref/pred; otherwise the strings
            # reach the dictionary exactly as the caller passed them
            ref, pred = ' '.join(ref_tokens), ' '.join(pred_tokens)

        # char level counts (computed on the whitespace-separated tokens)
        _, _, char_counts = speech_utils.edit_distance(
            ref_tokens,
            pred_tokens,
        )

        # word level counts
        ref_words = self.dictionary.tokens_to_sentence(ref,
                                                       use_unk_sym=False,
                                                       bpe_symbol=bpe_symbol)
        pred_words = self.dictionary.tokens_to_sentence(pred,
                                                        bpe_symbol=bpe_symbol)

        # filter words according to self.word_filters (support re.sub only)
        for pattern, repl in self.word_filters:
            ref_words = re.sub(pattern, repl, ref_words)
            pred_words = re.sub(pattern, repl, pred_words)

        ref_word_list, pred_word_list = ref_words.split(), pred_words.split()
        _, steps, word_counts = speech_utils.edit_distance(
            ref_word_list,
            pred_word_list,
        )
        alignment = speech_utils.aligned_print(
            ref_word_list,
            pred_word_list,
            steps,
        )

        # Commit only after every step above succeeded.
        self.char_counter += char_counts
        self.word_counter += word_counts
        self.aligned_results[utt_id] = alignment
Beispiel #2
0
def main(args):
    """Compute WER between a reference and a hypothesis file and print it.

    Reads these attributes of ``args``:

    * ``non_lang_syms``: optional path to a file with one symbol per line;
      tokens equal to one of these symbols are dropped before scoring.
    * ``wer_output_filter``: optional path to a filter file. Lines starting
      with ``#!`` and empty lines are skipped; lines of the form
      ``s/PATTERN/REPL/g`` or ``s:PATTERN:REPL:g`` become ``re.sub``
      substitutions; any other line is logged as unsupported and ignored.
    * ``ref_text`` / ``hyp_text``: paths to files whose lines are
      ``<utt_id> <text>``. A line holding only an id is treated as an
      utterance with empty text, and blank lines are skipped.

    Only utterances present in the hypothesis file are scored.

    Raises:
        AssertionError: on a malformed ``s/`` or ``s:`` filter line, a
            duplicated reference id, a hypothesis id missing from the
            references, or when no reference words were counted.
    """

    def split_id_and_text(line):
        # Return (utt_id, text), or None for a blank line.  An id with no
        # text (e.g. an empty hypothesis) yields '' instead of failing to
        # unpack.
        fields = line.strip().split(None, 1)
        if not fields:
            return None
        return fields[0], (fields[1] if len(fields) > 1 else '')

    non_lang_syms = set()
    if args.non_lang_syms is not None:
        with open(args.non_lang_syms, 'r', encoding='utf-8') as f:
            non_lang_syms = {x.rstrip() for x in f}

    word_filters = []
    if args.wer_output_filter is not None:
        with open(args.wer_output_filter, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line.startswith('#!') or line == '':
                    continue
                elif line.startswith('s/'):
                    m = re.match(r's/(\S+)/(\w*)/g', line)
                    assert m is not None
                    word_filters.append([m.group(1), m.group(2)])
                elif line.startswith('s:'):
                    m = re.match(r's:(\S+):(\w*):g', line)
                    assert m is not None
                    word_filters.append([m.group(1), m.group(2)])
                else:
                    logger.warning(
                        'Unsupported pattern: "{}". Ignoring it.'.format(line))

    refs = {}
    with open(args.ref_text, 'r', encoding='utf-8') as f:
        for line in f:
            entry = split_id_and_text(line)
            if entry is None:
                continue
            utt_id, text = entry
            assert utt_id not in refs, utt_id
            refs[utt_id] = text

    wer_counter = Counter()
    with open(args.hyp_text, 'r', encoding='utf-8') as f:
        for line in f:
            entry = split_id_and_text(line)
            if entry is None:
                continue
            utt_id, hyp = entry
            assert utt_id in refs, utt_id
            ref = refs[utt_id]

            # filter words according to word_filters (support re.sub only)
            for pattern, repl in word_filters:
                ref = re.sub(pattern, repl, ref)
                hyp = re.sub(pattern, repl, hyp)

            # filter out any non_lang_syms from ref and hyp
            ref_list = [x for x in ref.split() if x not in non_lang_syms]
            hyp_list = [x for x in hyp.split() if x not in non_lang_syms]

            _, _, counter = edit_distance(ref_list, hyp_list)
            wer_counter += counter

    assert wer_counter['words'] > 0
    # all rates are percentages of the number of reference words
    wer = float(wer_counter['sub'] + wer_counter['ins'] +
                wer_counter['del']) / wer_counter['words'] * 100
    sub = float(wer_counter['sub']) / wer_counter['words'] * 100
    ins = float(wer_counter['ins']) / wer_counter['words'] * 100
    dlt = float(wer_counter['del']) / wer_counter['words'] * 100

    print('WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%, #words={:d}'.
          format(wer, sub, ins, dlt, wer_counter['words']))
Beispiel #3
0
    def add_evaluation(self, utt_id, ref, pred):
        """Score one utterance and add it to the running statistics.

        Nothing on ``self`` is modified unless the whole evaluation succeeds:
        the duplicate-id check runs before any work is done, and the
        counters and the aligned result are only updated at the very end.

        Args:
            utt_id (str): id of the utterance; must not have been added
                before.
            ref (str): reference as whitespace-separated tokens.
            pred (str): prediction as whitespace-separated tokens.

        Raises:
            TypeError: if ``utt_id``, ``ref`` or ``pred`` is not a ``str``.
            AssertionError: if ``utt_id`` is already present in
                ``self.aligned_results``, or if the dictionary's
                ``non_lang_syms`` attribute is neither ``None`` nor a list.
        """
        if not isinstance(utt_id, str):
            raise TypeError("utt_id must be a string(got {})".format(
                type(utt_id)))
        if not isinstance(ref, str):
            raise TypeError("ref must be a string (got {})".format(type(ref)))
        if not isinstance(pred, str):
            raise TypeError("pred must be a string(got {})".format(type(pred)))

        # Reject duplicates up front: checking only after the counters have
        # been incremented would leave a rejected utterance counted anyway.
        assert (utt_id not in self.aligned_results
                ), "Duplicated utterance id detected: {}".format(utt_id)

        # filter out any non_lang_syms from ref and pred
        non_lang_syms = getattr(self.dictionary, "non_lang_syms", None)
        assert non_lang_syms is None or isinstance(non_lang_syms, list)
        ref_tokens, pred_tokens = ref.strip().split(), pred.strip().split()
        if non_lang_syms:
            excluded = set(non_lang_syms)  # O(1) membership per token
            ref_tokens = [x for x in ref_tokens if x not in excluded]
            pred_tokens = [x for x in pred_tokens if x not in excluded]
            # only the filtered case rewrites ref/pred; otherwise the strings
            # reach the dictionary exactly as the caller passed them
            ref, pred = " ".join(ref_tokens), " ".join(pred_tokens)

        # char level counts (computed on the whitespace-separated tokens)
        _, _, char_counts = speech_utils.edit_distance(
            ref_tokens,
            pred_tokens,
        )

        # word level counts
        ref_words = self.dictionary.wordpiece_decode(ref)
        pred_words = self.dictionary.wordpiece_decode(pred)

        # filter words according to self.word_filters (support re.sub only)
        for pattern, repl in self.word_filters:
            ref_words = re.sub(pattern, repl, ref_words)
            pred_words = re.sub(pattern, repl, pred_words)

        ref_word_list, pred_word_list = ref_words.split(), pred_words.split()
        _, steps, word_counts = speech_utils.edit_distance(
            ref_word_list,
            pred_word_list,
        )
        alignment = speech_utils.aligned_print(
            ref_word_list,
            pred_word_list,
            steps,
        )

        # Commit only after every step above succeeded.
        self.char_counter += char_counts
        self.word_counter += word_counts
        self.aligned_results[utt_id] = alignment
    def test_edit_distance(self):
        """Check the counter and step list of ``utils.edit_distance``."""
        keys = ('words', 'corr', 'sub', 'ins', 'del')
        # each case: (ref, hyp, counts in the order of `keys`, steps)
        cases = (
            ([], [], (0, 0, 0, 0, 0), []),
            (['a', 'b', 'c'], [], (3, 0, 0, 0, 3), ['del', 'del', 'del']),
            (
                ['a', 'b', 'c'],
                ['a', 'b', 'c'],
                (3, 3, 0, 0, 0),
                ['corr', 'corr', 'corr'],
            ),
            (
                ['a', 'b', 'c'],
                ['d', 'b', 'c', 'e', 'f'],
                (3, 2, 1, 2, 0),
                ['sub', 'corr', 'corr', 'ins', 'ins'],
            ),
            (
                ['b', 'c', 'd', 'e', 'f', 'h'],
                ['d', 'b', 'c', 'e', 'f', 'g'],
                (6, 4, 1, 1, 1),
                ['ins', 'corr', 'corr', 'del', 'corr', 'corr', 'sub'],
            ),
        )
        for reference, hypothesis, counts, expected_steps in cases:
            _, steps, counter = utils.edit_distance(reference, hypothesis)
            self.assertEqual(counter, Counter(dict(zip(keys, counts))))
            self.assertEqual(steps, expected_steps)
Beispiel #5
0
    def test_edit_distance(self):
        """Run ``utils.edit_distance`` on fixed pairs and verify its output."""

        def expected_counter(words, corr, sub, ins, dele):
            # "del" is a keyword, so the Counter is built from a dict literal
            return Counter({
                "words": words,
                "corr": corr,
                "sub": sub,
                "ins": ins,
                "del": dele,
            })

        def check(reference, hypothesis, counts, expected_steps):
            _, steps, counter = utils.edit_distance(reference, hypothesis)
            self.assertEqual(counter, expected_counter(*counts))
            self.assertEqual(steps, expected_steps)

        # both sequences empty
        check([], [], (0, 0, 0, 0, 0), [])
        # empty hypothesis: every reference token is a deletion
        check(["a", "b", "c"], [], (3, 0, 0, 0, 3), ["del", "del", "del"])
        # identical sequences
        check(
            ["a", "b", "c"],
            ["a", "b", "c"],
            (3, 3, 0, 0, 0),
            ["corr", "corr", "corr"],
        )
        # one substitution followed by trailing insertions
        check(
            ["a", "b", "c"],
            ["d", "b", "c", "e", "f"],
            (3, 2, 1, 2, 0),
            ["sub", "corr", "corr", "ins", "ins"],
        )
        # a mix of insertion, deletion and substitution
        check(
            ["b", "c", "d", "e", "f", "h"],
            ["d", "b", "c", "e", "f", "g"],
            (6, 4, 1, 1, 1),
            ["ins", "corr", "corr", "del", "corr", "corr", "sub"],
        )
Beispiel #6
0
def main(args):
    """Compute WER between a reference and a hypothesis file and print it.

    Reads these attributes of ``args``:

    * ``non_lang_syms``: optional path to a file with one symbol per line;
      tokens equal to one of these symbols are dropped before scoring.
    * ``wer_output_filter``: optional path to a filter file. Lines starting
      with ``#!`` and empty lines are skipped; lines of the form
      ``s/PATTERN/REPL/g`` or ``s:PATTERN:REPL:g`` become ``re.sub``
      substitutions; any other line is logged as unsupported and ignored.
    * ``ref_text`` / ``hyp_text``: paths to files whose lines are
      ``<utt_id> <text>``. A line holding only an id is treated as an
      utterance with empty text, and blank lines are skipped.

    Only utterances present in the hypothesis file are scored.

    Raises:
        AssertionError: on a malformed ``s/`` or ``s:`` filter line, a
            duplicated reference id, a hypothesis id missing from the
            references, or when no reference words were counted.
    """

    def split_id_and_text(line):
        # Return (utt_id, text), or None for a blank line.  An id with no
        # text (e.g. an empty hypothesis) yields "" instead of failing to
        # unpack.
        fields = line.strip().split(None, 1)
        if not fields:
            return None
        return fields[0], (fields[1] if len(fields) > 1 else "")

    non_lang_syms = set()
    if args.non_lang_syms is not None:
        with open(args.non_lang_syms, "r", encoding="utf-8") as f:
            non_lang_syms = {x.rstrip() for x in f}

    word_filters = []
    if args.wer_output_filter is not None:
        with open(args.wer_output_filter, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line.startswith("#!") or line == "":
                    continue
                elif line.startswith("s/"):
                    m = re.match(r"s/(\S+)/(\w*)/g", line)
                    assert m is not None
                    word_filters.append([m.group(1), m.group(2)])
                elif line.startswith("s:"):
                    m = re.match(r"s:(\S+):(\w*):g", line)
                    assert m is not None
                    word_filters.append([m.group(1), m.group(2)])
                else:
                    logger.warning(
                        "Unsupported pattern: '{}'. Ignoring it.".format(line))

    refs = {}
    with open(args.ref_text, "r", encoding="utf-8") as f:
        for line in f:
            entry = split_id_and_text(line)
            if entry is None:
                continue
            utt_id, text = entry
            assert utt_id not in refs, utt_id
            refs[utt_id] = text

    wer_counter = Counter()
    with open(args.hyp_text, "r", encoding="utf-8") as f:
        for line in f:
            entry = split_id_and_text(line)
            if entry is None:
                continue
            utt_id, hyp = entry
            assert utt_id in refs, utt_id
            ref = refs[utt_id]

            # filter words according to word_filters (support re.sub only)
            for pattern, repl in word_filters:
                ref = re.sub(pattern, repl, ref)
                hyp = re.sub(pattern, repl, hyp)

            # filter out any non_lang_syms from ref and hyp
            ref_list = [x for x in ref.split() if x not in non_lang_syms]
            hyp_list = [x for x in hyp.split() if x not in non_lang_syms]

            _, _, counter = edit_distance(ref_list, hyp_list)
            wer_counter += counter

    assert wer_counter["words"] > 0
    # all rates are percentages of the number of reference words
    wer = (
        float(wer_counter["sub"] + wer_counter["ins"] + wer_counter["del"]) /
        wer_counter["words"] * 100)
    sub = float(wer_counter["sub"]) / wer_counter["words"] * 100
    ins = float(wer_counter["ins"]) / wer_counter["words"] * 100
    dlt = float(wer_counter["del"]) / wer_counter["words"] * 100

    print("WER={:.2f}%, Sub={:.2f}%, Ins={:.2f}%, Del={:.2f}%, #words={:d}".
          format(wer, sub, ins, dlt, wer_counter["words"]))