Example #1
def add_pred(df, path, decoder_level):
    if not path:
        data = df[['src_char']].reset_index(drop=True)
        data = data.fillna('')
    else:
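        # one prediction per line; sep="\n" requires an older pandas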
        data = pd.read_csv(path, sep="\n", header=None, skip_blank_lines=False)
        data = data.fillna('')
    data.columns = ["prediction"]

    df = df.reset_index(drop=True)
    if decoder_level == 'char':
        df['prediction_char'] = data["prediction"]
        df["prediction"] = data["prediction"].apply(recover_space)
    else:
        df['prediction_char'] = data["prediction"].apply(replace_space)
        df["prediction"] = data["prediction"]

    errors, matches, ref_length = [], [], []
    errors_char, matches_char, ref_length_char = [], [], []
    df['entity_errors'] = 0
    for index, row in df.iterrows():
        # token
        ref_line = row['tgt_token']
        hyp_line = row['prediction']
        ref = ref_line.split()
        hyp = hyp_line.split()
        sm = SequenceMatcher(a=ref, b=hyp)
        errors.append(get_error_count(sm))
        matches.append(get_match_count(sm))
        ref_length.append(len(ref))

        # char
        ref = row['tgt_char'].split()
        hyp = row['prediction_char'].split()
        sm = SequenceMatcher(a=ref, b=hyp)
        errors_char.append(get_error_count(sm))
        matches_char.append(get_match_count(sm))
        ref_length_char.append(len(ref))

        # entity
        df.loc[index, 'entity_errors'] = 0  # was: sum([not clean_string(s) in hyp_line for s in row['entities_dic'].keys()])

    df['entity_count'] = df['entities_dic'].apply(len)

    df['token_errors'] = errors
    df['token_matches'] = matches
    df['token_length'] = ref_length

    df['char_errors'] = errors_char
    df['char_matches'] = matches_char
    df['char_length'] = ref_length_char

    df['sentence_count'] = 1
    df['sentence_error'] = 0
    df.loc[df['token_errors'] > 0, 'sentence_error'] = 1
    return df
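
Many of these snippets call get_error_count and get_match_count without defining them; those helpers live in the surrounding projects (asr-evaluation and its relatives), not in the edit_distance package itself. A minimal sketch of equivalent helpers, built only from SequenceMatcher's public opcodes and matching blocks, might look like this:

from edit_distance import SequenceMatcher

def get_error_count(sm):
    # An error opcode spans max(i2 - i1, j2 - j1) tokens.
    return sum(max(i2 - i1, j2 - j1)
               for tag, i1, i2, j1, j2 in sm.get_opcodes()
               if tag != 'equal')

def get_match_count(sm):
    # Matching blocks are [a_index, b_index, size] triples.
    return sum(size for _, _, size in sm.get_matching_blocks())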
Example #2
 def test_issue4(self):
     """ Test for error reported here:
     https://github.com/belambert/edit-distance/issues/4 """
     a = ['that', 'continuous', 'sanction', ':=', '(', 'flee', 'U', 'complain', ')', 'E', 'attendance', 'eye', '^', 'flowery', 'revelation', '^', 'ridiculous', 'destination', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>']
     b = ['continuous', ':=', '(', 'sanction', '^', 'flee', '^', 'attendance', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>']
     target_opcodes = [['delete', 0, 1, 0, 0], ['equal', 1, 2, 0, 1], ['delete', 2, 3, 0, 0], ['equal', 3, 4, 1, 2], ['equal', 4, 5, 2, 3], ['insert', 4, 4, 3, 4], ['insert', 4, 4, 4, 5], ['equal', 5, 6, 5, 6], ['replace', 6, 7, 6, 7], ['replace', 7, 8, 7, 8], ['replace', 8, 9, 8, 9], ['replace', 9, 10, 9, 10], ['replace', 10, 11, 10, 11], ['replace', 11, 12, 11, 12], ['replace', 12, 13, 12, 13], ['replace', 13, 14, 13, 14], ['replace', 14, 15, 14, 15], ['replace', 15, 16, 15, 16], ['replace', 16, 17, 16, 17], ['replace', 17, 18, 17, 18], ['equal', 18, 19, 18, 19], ['equal', 19, 20, 19, 20], ['equal', 20, 21, 20, 21], ['equal', 21, 22, 21, 22], ['equal', 22, 23, 22, 23], ['equal', 23, 24, 23, 24], ['equal', 24, 25, 24, 25], ['equal', 25, 26, 25, 26], ['equal', 26, 27, 26, 27], ['equal', 27, 28, 27, 28], ['equal', 28, 29, 28, 29]]
     sm = SequenceMatcher(a=a, b=b)
     self.assertEqual(sm.distance(), 16)
     self.assertEqual(sm.get_opcodes(), target_opcodes)
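
The expected distance of 16 is simply the error-opcode tally in target_opcodes: two 'delete', two 'insert', and twelve 'replace' operations.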
Example #3
 def test_issue4_simpler(self):
     """ Test for error reported here:
     https://github.com/belambert/edit-distance/issues/4 """
     a = ['that', 'continuous', 'sanction', ':=', '(']
     b = ['continuous', ':=', '(', 'sanction', '^']
     sm = SequenceMatcher(a=a, b=b)
     self.assertEqual(sm.distance(), 4)
     target_opcodes = [['delete', 0, 1, 0, 0], ['equal', 1, 2, 0, 1], ['delete', 2, 3, 0, 0], ['equal', 3, 4, 1, 2], ['equal', 4, 5, 2, 3], ['insert', 4, 4, 3, 4], ['insert', 4, 4, 4, 5]]
     self.assertEqual(sm.get_opcodes(), target_opcodes)
Example #4
 def test_unsupported(self):
     """Test if calling unimplemented methods actually generates an error."""
     a = ['a', 'b']
     b = ['a', 'b', 'd', 'c']
     sm = SequenceMatcher(a=a, b=b)
     with self.assertRaises(NotImplementedError):
         sm.find_longest_match(1, 2, 3, 4)
     with self.assertRaises(NotImplementedError):
         sm.get_grouped_opcodes()
Example #5
 def test_sequence_matcher2(self):
     """Test the sequence matcher."""
     a = ['a', 'b']
     b = ['a', 'b', 'd', 'c']
     sm = SequenceMatcher()
     sm.set_seq1(a)
     sm.set_seq2(b)
     self.assertTrue(sm.distance() == 2)
     sm.set_seqs(b, a)
     self.assertTrue(sm.distance() == 2)
Example #6
 def test_sequence_matcher2(self):
     """Test the sequence matcher."""
     a = ["a", "b"]
     b = ["a", "b", "d", "c"]
     sm = SequenceMatcher()
     sm.set_seq1(a)
     sm.set_seq2(b)
     self.assertEqual(sm.distance(), 2)
     sm.set_seqs(b, a)
     self.assertEqual(sm.distance(), 2)
Example #7
 def test_issue13(self):
     sm = SequenceMatcher(a="abc", b="abdc")
     self.assertEqual(
         [
             ["equal", 0, 1, 0, 1],
             ["equal", 1, 2, 1, 2],
             ["insert", 2, 2, 2, 3],
             ["equal", 2, 3, 3, 4],
         ],
         sm.get_opcodes(),
     )
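
As this test shows, SequenceMatcher accepts plain strings and treats them as character sequences. A quick sanity check along the same lines (assuming the same edit_distance import as the other tests):

sm = SequenceMatcher(a="abc", b="abdc")
assert sm.distance() == 1  # exactly one 'insert' opcode above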
Example #8
 def test_issue4_simpler(self):
     """ Test for error reported here:
     https://github.com/belambert/edit-distance/issues/4 """
     a = ['that', 'continuous', 'sanction', ':=', '(']
     b = ['continuous', ':=', '(', 'sanction', '^']
     sm = SequenceMatcher(a=a, b=b)
     self.assertEqual(sm.distance(), 4)
     target_opcodes = [['delete', 0, 1, 0, 0], ['equal', 1, 2, 0, 1],
                       ['delete', 2, 3, 0, 0], ['equal', 3, 4, 1, 2],
                       ['equal', 4, 5, 2, 3], ['insert', 4, 4, 3, 4],
                       ['insert', 4, 4, 4, 5]]
     self.assertEqual(sm.get_opcodes(), target_opcodes)
Example #9
 def test_sequence_matcher2(self):
     """Test the sequence matcher."""
     a = ['a', 'b']
     b = ['a', 'b', 'd', 'c']
     sm = SequenceMatcher()
     sm.set_seq1(a)
     sm.set_seq2(b)
     self.assertTrue(sm.distance() == 2)
     sm.set_seqs(b, a)
     self.assertTrue(sm.distance() == 2)
Example #10
    def get_match_score(self, prediction, target, processing="base"):
        assert processing in ["base", "structural"]

        if processing == "structural":
            prediction_clean = self.clean_structural(prediction)
            target_clean = self.clean_structural(target)
            if prediction_clean == [] and target_clean == []:
                return 1.0
        else:
            prediction_clean = self.clean_base(prediction)
            target_clean = self.clean_base(target)

        sm = SequenceMatcher(a=prediction_clean, b=target_clean)

        return sm.ratio()
Example #11
 def test_issue4_simpler(self):
     """Test for error reported here:
     https://github.com/belambert/edit-distance/issues/4"""
     a = ["that", "continuous", "sanction", ":=", "("]
     b = ["continuous", ":=", "(", "sanction", "^"]
     sm = SequenceMatcher(a=a, b=b)
     self.assertEqual(sm.distance(), 4)
     target_opcodes = [
         ["delete", 0, 1, 0, 0],
         ["equal", 1, 2, 0, 1],
         ["delete", 2, 3, 0, 0],
         ["equal", 3, 4, 1, 2],
         ["equal", 4, 5, 2, 3],
         ["insert", 5, 5, 3, 4],
         ["insert", 5, 5, 4, 5],
     ]
     self.assertEqual(sm.get_opcodes(), target_opcodes)
Example #12
 def test_sequence_matcher(self):
     """Test the sequence matcher."""
     a = ['a', 'b']
     b = ['a', 'b', 'd', 'c']
     sm = SequenceMatcher(a=a, b=b)
     opcodes = [['equal', 0, 1, 0, 1], ['equal', 1, 2, 1, 2],
                ['insert', 1, 1, 2, 3], ['insert', 1, 1, 3, 4]]
     self.assertTrue(sm.distance() == 2)
     self.assertTrue(sm.ratio() == 2 / 3)
     self.assertTrue(sm.quick_ratio() == 2 / 3)
     self.assertTrue(sm.real_quick_ratio() == 2 / 3)
     self.assertTrue(sm.distance() == 2)
     # This doesn't return anything, saves the value in the sm cache.
     self.assertTrue(not sm._compute_distance_fast())
     self.assertTrue(sm.get_opcodes() == opcodes)
     self.assertTrue(
         list(sm.get_matching_blocks()) == [[0, 0, 1], [1, 1, 1]])
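
The ratio assertions follow difflib's convention of 2·M / (len(a) + len(b)): with two matched tokens between a two-token and a four-token sequence, that is 2·2 / (2 + 4) = 2/3, which is what all three ratio variants return here.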
Example #13
    def get_match_score(self, prediction, target, processing="base"):
        assert processing in ["base", "structural"]

        if processing == "structural":
            prediction_clean = self.clean_structural(prediction)
            target_clean = self.clean_structural(target)
            if prediction_clean == [] and target_clean == []:
                return 1.0
        else:
            prediction_clean = self.clean_base(prediction)
            target_clean = self.clean_base(target)

        sm = SequenceMatcher(a=prediction_clean, b=target_clean)

        # editdistance workaround on empty sequences
        if not prediction_clean and not target_clean:
            return 1
        return sm.ratio()
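
The guard is needed because ratio() divides by the combined length of the two sequences, so a pair of empty sequences would otherwise divide by zero; returning 1 treats an empty prediction of an empty target as a perfect match.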
Example #14
 def test_sequence_matcher(self):
     """Test the sequence matcher."""
     a = ['a', 'b']
     b = ['a', 'b', 'd', 'c']
     sm = SequenceMatcher(a=a, b=b)
     opcodes = [['equal', 0, 1, 0, 1], ['equal', 1, 2, 1, 2], ['insert', 1, 1, 2, 3], ['insert', 1, 1, 3, 4]]
     self.assertTrue(sm.distance() == 2)
     self.assertTrue(sm.ratio() == 2 / 3)
     self.assertTrue(sm.quick_ratio() == 2 / 3)
     self.assertTrue(sm.real_quick_ratio() == 2 / 3)
     self.assertTrue(sm.distance() == 2)
     # This doesn't return anything, saves the value in the sm cache.
     self.assertTrue(not sm._compute_distance_fast())
     self.assertTrue(sm.get_opcodes() == opcodes)
     self.assertTrue(list(sm.get_matching_blocks()) == [[0, 0, 1], [1, 1, 1]])
Example #15
    def align(src, tgt):
        """Corrects misalignments between the gold and predicted tokens,
        which will almost always have different lengths due to inserted,
        deleted, or substituted tokens in the predicted system output."""

        sm = SequenceMatcher(
            a=list(map(lambda x: x[0], tgt)), b=list(map(lambda x: x[0], src)))
        tgt_temp, src_temp = [], []
        opcodes = sm.get_opcodes()
        for tag, i1, i2, j1, j2 in opcodes:
            # If they are equal, tag both sides 'e' and keep them as-is
            if tag == 'equal':
                for i in range(i1, i2):
                    tgt[i][1] = 'e'
                    tgt_temp.append(tgt[i])
                for i in range(j1, j2):
                    src[i][1] = 'e'
                    src_temp.append(src[i])
            # For insertions and deletions, tag the token 'd'/'i' and copy it to
            # the other side as a filler so the sequences stay the same length
            elif tag == 'delete':
                for i in range(i1, i2):
                    tgt[i][1] = 'd'
                    tgt_temp.append(tgt[i])
                for i in range(i1, i2):
                    src_temp.append(tgt[i])
            elif tag == 'insert':
                for i in range(j1, j2):
                    src[i][1] = 'i'
                    tgt_temp.append(src[i])
                for i in range(j1, j2):
                    src_temp.append(src[i])
            # More complicated logic for a substitution
            elif tag == 'replace':
                for i in range(i1, i2):
                    tgt[i][1] = 's'
                for i in range(j1, j2):
                    src[i][1] = 's'
                tgt_temp += tgt[i1:i2]
                src_temp += src[j1:j2]

        src, tgt = src_temp, tgt_temp
        return src, tgt
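
A hypothetical invocation (the token lists and tags here are invented for illustration): align expects mutable [token, tag] pairs and returns two equal-length sequences tagged 'e' (equal), 'i' (insert), 'd' (delete), or 's' (substitute).

src = [['the', ''], ['cat', ''], ['sat', '']]
tgt = [['the', ''], ['big', ''], ['cat', ''], ['sat', '']]
src_aligned, tgt_aligned = align(src, tgt)
# tgt's extra token 'big' is tagged 'd' and copied into src_aligned as a
# filler, so both outputs end up with four entries.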
Example #16
 def test_sequence_matcher(self):
     """Test the sequence matcher."""
     a = ["a", "b"]
     b = ["a", "b", "d", "c"]
     sm = SequenceMatcher(a=a, b=b)
     opcodes = [
         ["equal", 0, 1, 0, 1],
         ["equal", 1, 2, 1, 2],
         ["insert", 2, 2, 2, 3],
         ["insert", 2, 2, 3, 4],
     ]
     self.assertEqual(sm.distance(), 2)
     self.assertEqual(sm.ratio(), 2 / 3)
     self.assertEqual(sm.quick_ratio(), 2 / 3)
     self.assertEqual(sm.real_quick_ratio(), 2 / 3)
     self.assertEqual(sm.distance(), 2)
     # This doesn't return anything, saves the value in the sm cache.
     self.assertTrue(not sm._compute_distance_fast())
     self.assertEqual(sm.get_opcodes(), opcodes)
     self.assertEqual(list(sm.get_matching_blocks()), [[0, 0, 1], [1, 1, 1]])
Example #17
 def test_issue4(self):
     """ Test for error reported here:
     https://github.com/belambert/edit-distance/issues/4 """
     a = [
         'that', 'continuous', 'sanction', ':=', '(', 'flee', 'U',
         'complain', ')', 'E', 'attendance', 'eye', '^', 'flowery',
         'revelation', '^', 'ridiculous', 'destination', '<EOS>', '<EOS>',
         '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>',
         '<EOS>', '<EOS>'
     ]
     b = [
         'continuous', ':=', '(', 'sanction', '^', 'flee', '^',
         'attendance', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>',
         '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>',
         '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>', '<EOS>',
         '<EOS>'
     ]
     target_opcodes = [['delete', 0, 1, 0, 0], ['equal', 1, 2, 0, 1],
                       ['delete', 2, 3, 0, 0], ['equal', 3, 4, 1, 2],
                       ['equal', 4, 5, 2, 3], ['insert', 4, 4, 3, 4],
                       ['insert', 4, 4, 4, 5], ['equal', 5, 6, 5, 6],
                       ['replace', 6, 7, 6, 7], ['replace', 7, 8, 7, 8],
                       ['replace', 8, 9, 8, 9], ['replace', 9, 10, 9, 10],
                       ['replace', 10, 11, 10, 11],
                       ['replace', 11, 12, 11, 12],
                       ['replace', 12, 13, 12, 13],
                       ['replace', 13, 14, 13, 14],
                       ['replace', 14, 15, 14, 15],
                       ['replace', 15, 16, 15, 16],
                       ['replace', 16, 17, 16, 17],
                       ['replace', 17, 18, 17, 18],
                       ['equal', 18, 19, 18, 19], ['equal', 19, 20, 19, 20],
                       ['equal', 20, 21, 20, 21], ['equal', 21, 22, 21, 22],
                       ['equal', 22, 23, 22, 23], ['equal', 23, 24, 23, 24],
                       ['equal', 24, 25, 24, 25], ['equal', 25, 26, 25, 26],
                       ['equal', 26, 27, 26, 27], ['equal', 27, 28, 27, 28],
                       ['equal', 28, 29, 28, 29]]
     sm = SequenceMatcher(a=a, b=b)
     self.assertEqual(sm.distance(), 16)
     self.assertEqual(sm.get_opcodes(), target_opcodes)
Example #18
def print_errors(df, n, random_state=1):
    df = df[df.token_errors > 0][['src_token', 'tgt_token', 'prediction']]
    print(len(df))
    df = df.sample(frac=n, random_state=random_state)
    for src_line, ref_line, hyp_line in zip(df['src_token'].values,
                                            df['tgt_token'].values,
                                            df['prediction'].values):
        ref = ref_line.split()
        hyp = hyp_line.split()
        sm = SequenceMatcher(a=ref, b=hyp)
        print("SRC:", src_line)
        print_diff(sm, ref, hyp)
        print()
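
Note that n is forwarded as DataFrame.sample's frac argument, so despite its name it is a fraction of the error rows (n=0.1 prints roughly 10% of them), not an absolute count.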
Example #19
def char_error_rate(true_labels, pred_labels, decoded=False):
    cers = np.empty(len(true_labels))

    if not decoded:
        pred_labels = for_tf_or_th(pred_labels, pred_labels.swapaxes(0, 1))

    for i, (true_label, pred_label) in enumerate(zip(true_labels, pred_labels)):
        prediction = pred_label if decoded else argmax_decode(pred_label)
        ratio = SequenceMatcher(true_label, prediction).ratio()
        cers[i] = 1 - ratio
    return cers
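
Because ratio() is the symmetric 2·M / (len(a) + len(b)), the 1 - ratio value computed here is a similarity-based error score rather than the conventional CER of edit distance divided by reference length; the two only agree approximately.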
Example #20
def wer(ref, hyp):
    ref = ref.strip().replace("'", "").split()
    hyp = hyp.strip().replace("'", "").split()
    ref = list(map(str.lower, ref))
    hyp = list(map(str.lower, hyp))

    ref_token_count = len(ref)

    sm = SequenceMatcher(a=ref, b=hyp)
    # error_count = get_error_count(sm)
    match_count = get_match_count(sm)
    # wrr = match_count / ref_token_count
    wer = 1 - match_count / ref_token_count
    return wer, match_count, ref_token_count
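
Since every reference token is either matched, substituted, or deleted, 1 - match_count / ref_token_count equals (S + D) / N: insertion errors never raise this score, unlike the standard (S + D + I) / N formulation that the commented-out get_error_count path would support.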
Example #21
    def _update_state(self, y_true, y_pred, mask=None):
        for i in range(len(y_true)):
            assert len(y_true[i]) == len(y_pred[i])

            # select utterance
            y_true_ = y_true[i]
            y_pred_ = y_pred[i]

            # remove padding
            y_true_ = y_true_[mask[i]]
            y_pred_ = y_pred_[mask[i]]

            # merge consequent states
            y_true_ = _merge_consequent_states(y_true_)
            y_pred_ = _merge_consequent_states(y_pred_)

            # compute edit distance
            sm = SequenceMatcher(a=y_true_, b=y_pred_)
            edit_distance = sm.distance()

            # update state
            self.edit_distance.assign_add(edit_distance)
            self.length.assign_add(len(y_true_))
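
Presumably the metric's result method divides the accumulated edit_distance by length, giving a length-normalized (PER-style) error rate over all utterances seen so far; that method is not shown in this snippet.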
Example #22
 def test_unsupported(self):
     """Test if calling unimplemented methods actually generates an error."""
     a = ['a', 'b']
     b = ['a', 'b', 'd', 'c']
     sm = SequenceMatcher(a=a, b=b)
     with self.assertRaises(NotImplementedError):
         sm.find_longest_match(1, 2, 3, 4)
     with self.assertRaises(NotImplementedError):
         sm.get_grouped_opcodes()
Example #23
def process_line_pair(ref_line, hyp_line):
    """Given a pair of strings corresponding to a reference and hypothesis,
    compute the edit distance, print if desired, and keep track of results
    in global variables."""
    # I don't believe these all need to be global.  In any case, they shouldn't be.
    global error_count
    global match_count
    global ref_token_count

    ref = ref_line.split()
    hyp = hyp_line.split()
    id_ = None

    # If the files have IDs, then split the ID off from the text
    if files_have_ids:
        remove_sentence_ids(ref, hyp)

    # Create an object to get the edit distance, and then retrieve the
    # relevant counts that we need.
    sm = SequenceMatcher(a=ref, b=hyp)
    errors = get_error_count(sm)
    matches = get_match_count(sm)
    ref_length = len(ref)

    # Increment the total counts we're tracking
    error_count += errors
    match_count += matches
    ref_token_count += ref_length

    # If we're keeping track of which words get mixed up with which others, call track_confusions
    if confusions:
        track_confusions(sm, ref, hyp)

    # If we're printing instances, do it here (in roughly the align.c format)
    if print_instances_p:
        print_instances(ref, hyp, sm, id_=id_)

    # Keep track of the individual error rates, and reference lengths, so we
    # can compute average WERs by sentence length
    lengths.append(ref_length)
    if len(ref) > 0:
        error_rate = errors * 1.0 / len(ref)
    else:
        error_rate = float("inf")
    error_rates.append(error_rate)
    wer_bins[len(ref)].append(error_rate)
Example #24
def print_diff(s1,
               s2,
               prefix1='REF:',
               prefix2='HYP:',
               suffix1=None,
               suffix2=None):
    """Print a readable diff between two sentences.

    This is the only place we use anything from asr-evaluation."""
    a = s1.words
    b = s2.words
    sm = SequenceMatcher(a, b)
    eval_print_diff(sm,
                    s1.words,
                    s2.words,
                    prefix1=prefix1,
                    prefix2=prefix2,
                    suffix1=suffix1,
                    suffix2=suffix2)
Example #25
    def compute_levenshtein(self, pred, gt, max_levenshtein=5):
        assert len(pred) == len(gt)
        total = len(gt)
        # Maintain the count of each edit-distances
        counter = Counter()
        max_chars = 0
        for ref, hyp in zip(gt, pred):
            max_chars += max(len(ref), len(hyp))
            counter[SequenceMatcher(a=ref, b=hyp).distance()] += 1
        avg_chars = max_chars / total

        # Compute accuracy for each edit-distance
        accuracies = [0] * max_levenshtein
        total_dist = 0
        for levenshtein, count in counter.items():
            if levenshtein < max_levenshtein:
                accuracies[levenshtein] = count * 100 / total
            total_dist += levenshtein * count

        avg_dist = total_dist / total
        return accuracies, avg_dist, avg_dist / avg_chars
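
Note that accuracies[k] is the percentage of pairs whose edit distance is exactly k, not at most k, so a cumulative accuracy curve would need a running sum; the last return value normalizes the average distance by the average longer-string length, giving a CER-like figure.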
Example #26
# frame-by-frame at the phoneme level
y_pred_phones = states2phones(y_pred, phones, stateList)
y_true_phones = states2phones(y_true, phones, stateList)
accuracy.reset_states()
accuracy.update_state(y_true_phones, y_pred_phones)
print('Frame-by-frame accuracy at the phoneme level: {:.2f}%'.format(
    accuracy.result().numpy() * 100))
plt.figure()
plot_confusion_matrix(y_true_phones, y_pred_phones)
plt.title('Frame-by-frame confusion matrix at the phoneme level')

# PER at the state level
N = 10000  # number of frames to consider (distance computation is expensive)
y_pred_merged = merge_consequent_states(y_pred[:N])
y_true_merged = merge_consequent_states(y_true[:N])
sm = SequenceMatcher(a=y_true_merged, b=y_pred_merged)
edit_distance = sm.distance()
print('PER at the state level: {:.2f}%'.format(edit_distance / N * 100))

# PER at the phoneme level
y_pred_merged = merge_consequent_states(y_pred_phones[:N])
y_true_merged = merge_consequent_states(y_true_phones[:N])
sm = SequenceMatcher(a=y_true_merged, b=y_pred_merged)
edit_distance = sm.distance()
print('PER at the phoneme level: {:.2f}%'.format(edit_distance / N * 100))

# posteriors for first utterance
utterance = testdata[0]
x, y = prepare_matrices([utterance],
                        K,
                        feature_type,
Example #27
 def test_issue4(self):
     """Test for error reported here:
     https://github.com/belambert/edit-distance/issues/4"""
     a = [
         "that",
         "continuous",
         "sanction",
         ":=",
         "(",
         "flee",
         "U",
         "complain",
         ")",
         "E",
         "attendance",
         "eye",
         "^",
         "flowery",
         "revelation",
         "^",
         "ridiculous",
         "destination",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
     ]  # noqa
     b = [
         "continuous",
         ":=",
         "(",
         "sanction",
         "^",
         "flee",
         "^",
         "attendance",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
         "<EOS>",
     ]  # noqa
     target_opcodes = [
         ["delete", 0, 1, 0, 0],
         ["equal", 1, 2, 0, 1],
         ["delete", 2, 3, 0, 0],
         ["equal", 3, 4, 1, 2],
         ["equal", 4, 5, 2, 3],
         ["insert", 5, 5, 3, 4],
         ["insert", 5, 5, 4, 5],
         ["equal", 5, 6, 5, 6],
         ["replace", 6, 7, 6, 7],
         ["replace", 7, 8, 7, 8],
         ["replace", 8, 9, 8, 9],
         ["replace", 9, 10, 9, 10],
         ["replace", 10, 11, 10, 11],
         ["replace", 11, 12, 11, 12],
         ["replace", 12, 13, 12, 13],
         ["replace", 13, 14, 13, 14],
         ["replace", 14, 15, 14, 15],
         ["replace", 15, 16, 15, 16],
         ["replace", 16, 17, 16, 17],
         ["replace", 17, 18, 17, 18],
         ["equal", 18, 19, 18, 19],
         ["equal", 19, 20, 19, 20],
         ["equal", 20, 21, 20, 21],
         ["equal", 21, 22, 21, 22],
         ["equal", 22, 23, 22, 23],
         ["equal", 23, 24, 23, 24],
         ["equal", 24, 25, 24, 25],
         ["equal", 25, 26, 25, 26],
         ["equal", 26, 27, 26, 27],
         ["equal", 27, 28, 27, 28],
         ["equal", 28, 29, 28, 29],
     ]  # noqa
     sm = SequenceMatcher(a=a, b=b)
     self.assertEqual(sm.distance(), 16)
     self.assertEqual(sm.get_opcodes(), target_opcodes)
Example #28
    def calculate_metrics(
            self,
            ref,
            rec,
            ignore_punct=True,
            label=None,
            print_verbosity=0,  # 0 = nothing, 1 = errors, 2 = all
            exclude=None,
            query_keyword=None):
        self.counter = 0
        self.insertion_table = defaultdict(int)
        self.deletion_table = defaultdict(int)
        self.substitution_table = defaultdict(int)
        table = str.maketrans({key: None for key in string.punctuation})
        error_count = 0
        match_count = 0
        ref_token_count = 0
        sent_error_count = 0
        lengths = []
        error_rates = []
        wer_bins = defaultdict(list)
        wrr = 0.0
        wer = 0.0
        ser = 0.0
        id_ = ''

        for i, (ref_line, rec_line) in enumerate(zip(ref, rec)):
            if ignore_punct:
                _ref = ref_line.translate(table).split()
                _rec = rec_line.translate(table).split()
            else:
                _ref = ref_line.split()
                _rec = rec_line.split()

            if label is not None:
                id_ = label[i]

            if exclude is not None:
                if exclude in id_:
                    continue

            if self.case_lower:
                _ref = list(map(str.lower, _ref))
                _rec = list(map(str.lower, _rec))

            if query_keyword is not None:
                if not any(k in _ref for k in query_keyword):
                    continue

            sm = SequenceMatcher(a=_ref, b=_rec)
            errors = self.get_error_count(sm)
            matches = self.get_match_count(sm)
            ref_length = len(_ref)

            # Increment the total counts we're tracking
            error_count += errors
            match_count += matches
            ref_token_count += ref_length
            self.counter += 1

            if errors != 0:
                sent_error_count += 1

            self.track_confusions(sm, _ref, _rec)

            # If we're printing instances, do it here (in roughly the align.c format)
            if print_verbosity == 2 or (print_verbosity == 1 and errors != 0):
                self.print_all(_ref, _rec, sm, id_=id_)

            lengths.append(ref_length)
            if len(_ref) > 0:
                error_rate = errors * 1.0 / len(_ref)
            else:
                error_rate = float("inf")
            error_rates.append(error_rate)
            wer_bins[len(_ref)].append(error_rate)

        # Compute WER and WRR
        if ref_token_count > 0:
            wrr = match_count / ref_token_count
            wer = error_count / ref_token_count

        # Compute SER
        if self.counter > 0:
            ser = sent_error_count / self.counter

        print(f'\nSentence count: {self.counter}')
        print(f'WER: {wer:0.3%} ({error_count} / {ref_token_count})')
        print(f'WRR: {wrr:0.3%} ({match_count} / {ref_token_count})')
        print(f'SER: {ser:0.3%} ({sent_error_count} / {self.counter})')
        return wer, wrr, ser
Example #29
        # pred_one_layer_trans = transcribe(y_sample, prediction_one_layer)
        # pred_four_layer_trans = transcribe(y_sample, prediction_four_layer)
        y_transcribed = transcribe(y_sample)
        pred_one_layer_trans = transcribe(prediction_one_layer)
        pred_four_layer_trans = transcribe(prediction_four_layer)

        fig, axs = plt.subplots(3)
        axs[0].set_title("Correct output, state level merged")
        axs[0].pcolormesh(y_transcribed.T)
        axs[1].set_title(name + " 1 layer")
        axs[1].pcolormesh(pred_one_layer_trans.T)
        axs[2].set_title(name + " 4 layers")
        axs[2].pcolormesh(pred_four_layer_trans.T)
        plt.show()

        seq1 = SequenceMatcher(y_transcribed, prediction_four_layer_merged)
        distance = seq1.distance() / 324 * 100

        # TODO: Then measure the Phone Error Rate (PER),
        #  that is the length normalised edit distance between the sequence
        #  of states from the DNN and the correct transcription

        # TODO: Use SequenceMatcher from edit distance to quickly calculate PER

        # Part 4: phoneme-level edit distance
        # y_transcribed_merged = transcribe(y_sample_merged, y_sample_merged)
        # pred_one_layer_trans_merged = transcribe(y_sample_merged, prediction_one_layer_merged)
        # pred_four_layer_trans_merged = transcribe(y_sample_merged, prediction_four_layer_merged)
        y_transcribed_merged = transcribe(y_sample_merged)
        pred_one_layer_trans_merged = transcribe(prediction_one_layer_merged)
        pred_four_layer_trans_merged = transcribe(prediction_four_layer_merged)
Example #30
def process_line_pair(ref_line, hyp_line, case_insensitive=False, remove_empty_refs=False):
    """Given a pair of strings corresponding to a reference and hypothesis,
    compute the edit distance, print if desired, and keep track of results
    in global variables.

    Return true if the pair was counted, false if the pair was not counted due
    to an empty reference string."""
    # I don't believe these all need to be global.  In any case, they shouldn't be.
    global error_count
    global match_count
    global ref_token_count
    global sent_error_count

    # Split into tokens by whitespace
    ref = ref_line.split()
    hyp = hyp_line.split()
    id_ = None

    # If the files have IDs, then split the ID off from the text
    if files_head_ids:
        id_ = ref[0]
        ref, hyp = remove_head_id(ref, hyp)
    elif files_tail_ids:
        id_ = ref[-1]
        ref, hyp = remove_tail_id(ref, hyp)

    if case_insensitive:
        ref = list(map(str.lower, ref))
        hyp = list(map(str.lower, hyp))
    if remove_empty_refs and len(ref) == 0:
        return False

    # Create an object to get the edit distance, and then retrieve the
    # relevant counts that we need.
    sm = SequenceMatcher(a=ref, b=hyp)
    errors = get_error_count(sm)
    matches = get_match_count(sm)
    ref_length = len(ref)

    # Increment the total counts we're tracking
    error_count += errors
    match_count += matches
    ref_token_count += ref_length

    if errors != 0:
        sent_error_count += 1

    # If we're keeping track of which words get mixed up with which others, call track_confusions
    if confusions:
        track_confusions(sm, ref, hyp)

    # If we're printing instances, do it here (in roughly the align.c format)
    if print_instances_p or (print_errors_p and errors != 0):
        print_instances(ref, hyp, sm, id_=id_)

    # Keep track of the individual error rates, and reference lengths, so we
    # can compute average WERs by sentence length
    lengths.append(ref_length)
    if len(ref) > 0:
        error_rate = errors * 1.0 / len(ref)
    else:
        error_rate = float("inf")
    error_rates.append(error_rate)
    wer_bins[len(ref)].append(error_rate)
    return True