예제 #1
0
def main():
    """Print the overall and the oracle evaluation of a set of n-best lists.

    Command-line arguments:
        nbest_file: file of n-best lists, opened through ``open_file_stream``
            (per its help text, read as gzip when the name ends in ``.gz``).
        ref_file: reference transcripts, opened for reading by argparse.

    Both input streams are closed before returning, including when reading
    or evaluation raises.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "nbest_file",
        type=open_file_stream,
        help='A file containing n-best lists.  Read as a gzip file if filename ends with .gz')
    parser.add_argument("ref_file", type=argparse.FileType('r'))
    # parser.add_argument('--plot', '-p', default=False, action='store_true')

    args = parser.parse_args()

    try:
        print('Reading n-best lists...')
        nbests = list(read_nbest_file(args.nbest_file))
        print('# of nbests: {}'.format(len(nbests)))
        print('Reading transcripts...')
        refs = read_transcript_table(args.ref_file)
        # evaluate_nbests() is not handed the references directly, so they
        # are published through this module-level table first.
        asr_tools.evaluation_util.REFERENCES = refs

        # This is the slow part.
        print('Running evaluation...')
        overall_eval = evaluate_nbests(nbests)
        print('Overall eval:')
        print(overall_eval)
        print()
        print('Computing oracle eval...')
        print('Oracle eval:')
        print(evaluate_nbests_oracle(nbests))

        # NOTE(review): `wers` only feeds the disabled plot below and is
        # otherwise unused; kept so the plot can be re-enabled as-is.
        evals = evals_by_depth(nbests)
        wers = [depth_eval.wer() for depth_eval in evals]

        # if args.plot:
        #     plt.plot(wers)
        #     plt.ylim(ymin=0)
        #     plt.show()
    finally:
        # Release both inputs even when something above failed; the original
        # only closed nbest_file, and only on the success path.
        args.nbest_file.close()
        args.ref_file.close()
예제 #2
0
def main():
    """Evaluate n-best lists against references and write them to a file.

    Command-line arguments:
        nbest_file: file of n-best lists, opened through ``open_file_stream``
            (per its help text, read as gzip when the name ends in ``.gz``).
        ref_file: reference transcripts, opened for reading by argparse.
        output: destination file, opened for writing by argparse.

    All three streams are closed before returning, so the written output is
    flushed even if a later step raises.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "nbest_file",
        type=open_file_stream,
        help=
        'A file containing n-best lists.  Read as a gzip file if filename ends with .gz'
    )
    parser.add_argument("ref_file", type=argparse.FileType('r'))
    parser.add_argument("output", type=argparse.FileType('w'))
    args = parser.parse_args()

    try:
        print('Reading n-best lists...')
        nbests = list(read_nbest_file(args.nbest_file))
        print('# of nbests: {}'.format(len(nbests)))
        print('Reading transcripts...')
        refs = read_transcript_table(args.ref_file)
        # evaluate_nbests() is not handed the references directly, so they
        # are published through this module-level table first.
        asr_tools.evaluation_util.REFERENCES = refs

        # This is the slow part.  The returned summary is not needed here;
        # presumably the call records per-sentence evals on the n-bests,
        # which save_eval=True below then writes out (TODO confirm).
        print('Running evaluation...')
        evaluate_nbests(nbests)

        # Write them back out to a file
        write_nbests(args.output, nbests, save_eval=True)
    finally:
        # The original never closed ref_file or output, and skipped closing
        # nbest_file on errors.
        args.nbest_file.close()
        args.ref_file.close()
        args.output.close()
예제 #3
0
 def test_e2e_evaluation(self):
     """Run a full end-to-end evaluation of ASR hypotheses.

     Every hypothesis in ``self.hyp_file`` is evaluated against the
     reference table read from ``self.ref_file``; the per-hypothesis
     evaluations are summed and the totals compared with known values.
     """
     with open(self.hyp_file) as h, open(self.ref_file) as r:
         ref_table = read_transcript_table(r)
         hyps = read_transcript(h)
         evals = [evaluate(ref_table, hyp) for hyp in hyps]
     # Seed sum() with the first evaluation so no zero-valued evaluation
     # object is needed.
     overall_eval = sum(evals[1:], evals[0])
     # assertEqual reports both values on failure, unlike assertTrue(a == b).
     self.assertEqual(overall_eval.ref_len, 335)
     self.assertEqual(overall_eval.matches, 307)
     self.assertEqual(overall_eval.errs, 32)
예제 #4
0
 def test_e2e_evaluation(self):
     """Check the summed evaluation of all hypotheses against known totals.

     Reads the reference table and the hypotheses from ``self.ref_file``
     and ``self.hyp_file``, evaluates each hypothesis, adds the results
     together and verifies the reference length, matches and errors.
     """
     with open(self.hyp_file) as h, open(self.ref_file) as r:
         ref_table = read_transcript_table(r)
         hyps = read_transcript(h)
         evals = [evaluate(ref_table, hyp) for hyp in hyps]
     # The first evaluation is the start value, so sum() never has to add
     # an evaluation object to the integer 0.
     overall_eval = sum(evals[1:], evals[0])
     # Use assertEqual so a failure shows the actual and expected numbers.
     self.assertEqual(overall_eval.ref_len, 335)
     self.assertEqual(overall_eval.matches, 307)
     self.assertEqual(overall_eval.errs, 32)
예제 #5
0
    def setUpClass(cls):
        """Load the fixtures shared by the tests and store them on the class.

        Sets ``refs``, ``hyps`` and ``nbests`` from the class's file paths,
        fills the global reference table, and precomputes the evaluations
        ``e1`` (first hypothesis) and ``e2`` (its global reference).
        """
        with open(cls.ref_file) as ref_stream:
            cls.refs = read_transcript_table(ref_stream)
        # The reference file is opened a second time to fill the global table.
        with open(cls.ref_file) as ref_stream:
            set_global_references(ref_stream)
        with open(cls.hyp_file) as hyp_stream:
            cls.hyps = read_transcript(hyp_stream)
        with open(cls.nbest_file) as nbest_stream:
            cls.nbests = [nbest for nbest in read_nbest_file(nbest_stream)]

        # s1 is the first hypothesis; s2 is the global reference with its id.
        first_hyp = cls.hyps[0]
        cls.s1 = first_hyp
        cls.s2 = get_global_reference(first_hyp.id_)

        # Evaluate both sentences against the reference table.
        cls.e1, cls.e2 = evaluate(cls.refs, cls.s1), evaluate(cls.refs, cls.s2)
예제 #6
0
    def setUpClass(cls):
        """Prepare class-level fixtures: references, hypotheses and n-bests.

        Also populates the global reference table and stores two sentences
        (``s1``, ``s2``) together with their evaluations (``e1``, ``e2``).
        """
        # Reference table kept on the class.
        with open(cls.ref_file) as refs_in:
            reference_table = read_transcript_table(refs_in)
        cls.refs = reference_table

        # Same file again, this time loaded into the global table.
        with open(cls.ref_file) as refs_in:
            set_global_references(refs_in)

        with open(cls.hyp_file) as hyps_in:
            hypotheses = read_transcript(hyps_in)
        cls.hyps = hypotheses

        with open(cls.nbest_file) as nbests_in:
            cls.nbests = list(read_nbest_file(nbests_in))

        # First hypothesis and the global reference sharing its id.
        cls.s1 = hypotheses[0]
        cls.s2 = get_global_reference(cls.s1.id_)

        # Evaluate each of them against the reference table.
        cls.e1 = evaluate(reference_table, cls.s1)
        cls.e2 = evaluate(reference_table, cls.s2)
예제 #7
0
def main():
    """Print the n-best lists that are potentially improvable.

    Command-line arguments:
        nbest_file: file of n-best lists, opened through ``open_file_stream``
            (per its help text, read as gzip when the name ends in ``.gz``).
        ref_file: reference transcripts, opened for reading by argparse.

    Every n-best list whose ``is_improveable()`` is true is printed, followed
    by the overall evaluation.  Both input streams are closed before
    returning.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "nbest_file",
        type=open_file_stream,
        help='A file containing n-best lists.  Read as a gzip file if filename ends with .gz')
    parser.add_argument("ref_file", type=argparse.FileType('r'))
    args = parser.parse_args()

    try:
        nbests = list(read_nbest_file(args.nbest_file))
        refs = read_transcript_table(args.ref_file)
        # evaluate_nbests() is not handed the references directly, so they
        # are published through this module-level table first.
        asr_tools.evaluation_util.REFERENCES = refs

        overall_eval = evaluate_nbests(nbests)
        for nbest in nbests:
            if nbest.is_improveable():
                print_nbest_ref_hyp_best(nbest)
        print(overall_eval)
    finally:
        # The original left both files open for the life of the process.
        args.nbest_file.close()
        args.ref_file.close()
예제 #8
0
def main():
    """Print every n-best list to the console, optionally with an evaluation.

    Command-line arguments:
        nbest_file: file of n-best lists, opened for reading by argparse.
        ref_file: optional reference transcripts; when given, the n-bests
            are evaluated and the overall evaluation is printed last.
        --verbose / -v: accepted but not read by this function.

    A red warning is printed after any n-best list whose sentence scores
    fail the ``monotone`` check.  Opened input files are closed before
    returning.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("nbest_file", type=argparse.FileType('r'))
    parser.add_argument("ref_file", nargs='?', type=argparse.FileType('r'))  # optional
    # NOTE(review): this option takes a value (no store_true action) and is
    # never consulted below; left as-is to keep the CLI unchanged.
    parser.add_argument("--verbose", '-v', default=True)
    args = parser.parse_args()
    colorama.init()

    try:
        nbests = list(read_nbest_file(args.nbest_file))
        overall_eval = None
        if args.ref_file:
            refs = read_transcript_table(args.ref_file)
            # evaluate_nbests() is not handed the references directly, so
            # they are published through this module-level table first.
            asr_tools.evaluation_util.REFERENCES = refs
            overall_eval = evaluate_nbests(nbests)

        for nbest in nbests:
            print('NBEST:')
            print_nbest(nbest, acscore=True, lmscore=True, tscore=True, maxwords=10)
            if not monotone(nbest.sentences, comparison=operator.lt, key=Sentence.score):
                print(termcolor.colored('WARNING: Non-montonic scores', 'red', attrs=['bold']))

        if args.ref_file:
            print(overall_eval)
    finally:
        # The original never closed either file.
        args.nbest_file.close()
        if args.ref_file:
            args.ref_file.close()
예제 #9
0
def main():
    """Report the overall and the oracle evaluation for n-best lists.

    Command-line arguments:
        nbest_file: file of n-best lists, opened through ``open_file_stream``
            (per its help text, read as gzip when the name ends in ``.gz``).
        ref_file: reference transcripts, opened for reading by argparse.

    Progress messages and both evaluations are printed to standard output.
    Both input streams are closed before returning, even on failure.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "nbest_file",
        type=open_file_stream,
        help=
        'A file containing n-best lists.  Read as a gzip file if filename ends with .gz'
    )
    parser.add_argument("ref_file", type=argparse.FileType('r'))
    # parser.add_argument('--plot', '-p', default=False, action='store_true')

    args = parser.parse_args()

    try:
        print('Reading n-best lists...')
        nbests = list(read_nbest_file(args.nbest_file))
        print('# of nbests: {}'.format(len(nbests)))
        print('Reading transcripts...')
        refs = read_transcript_table(args.ref_file)
        # The references are not passed to evaluate_nbests() below, so they
        # are installed in this module-level table beforehand.
        asr_tools.evaluation_util.REFERENCES = refs

        # This is the slow part.
        print('Running evaluation...')
        overall_eval = evaluate_nbests(nbests)
        print('Overall eval:')
        print(overall_eval)
        print()
        print('Computing oracle eval...')
        print('Oracle eval:')
        print(evaluate_nbests_oracle(nbests))

        # NOTE(review): `wers` is consumed only by the disabled plot below;
        # it is retained so the plot can be switched back on unchanged.
        evals = evals_by_depth(nbests)
        wers = [depth_eval.wer() for depth_eval in evals]

        # if args.plot:
        #     plt.plot(wers)
        #     plt.ylim(ymin=0)
        #     plt.show()
    finally:
        # Previously ref_file was never closed and nbest_file was closed
        # only when everything above succeeded.
        args.nbest_file.close()
        args.ref_file.close()
예제 #10
0
def main():
    """Show which n-best lists are potentially improvable.

    Command-line arguments:
        nbest_file: file of n-best lists, opened through ``open_file_stream``
            (per its help text, read as gzip when the name ends in ``.gz``).
        ref_file: reference transcripts, opened for reading by argparse.

    Prints each n-best list for which ``is_improveable()`` is true, then the
    overall evaluation.  Both input streams are closed before returning.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "nbest_file",
        type=open_file_stream,
        help=
        'A file containing n-best lists.  Read as a gzip file if filename ends with .gz'
    )
    parser.add_argument("ref_file", type=argparse.FileType('r'))
    args = parser.parse_args()

    try:
        nbests = list(read_nbest_file(args.nbest_file))
        refs = read_transcript_table(args.ref_file)
        # The references are not passed to evaluate_nbests() below, so they
        # are installed in this module-level table beforehand.
        asr_tools.evaluation_util.REFERENCES = refs

        overall_eval = evaluate_nbests(nbests)
        for nbest in nbests:
            if nbest.is_improveable():
                print_nbest_ref_hyp_best(nbest)
        print(overall_eval)
    finally:
        # Neither file was closed by the original implementation.
        args.nbest_file.close()
        args.ref_file.close()
예제 #11
0
def main():
    """Evaluate n-best lists and save them, with their evals, to a file.

    Command-line arguments:
        nbest_file: file of n-best lists, opened through ``open_file_stream``
            (per its help text, read as gzip when the name ends in ``.gz``).
        ref_file: reference transcripts, opened for reading by argparse.
        output: destination file, opened for writing by argparse.

    All three streams are closed before returning, which also guarantees the
    output is flushed to disk.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "nbest_file",
        type=open_file_stream,
        help='A file containing n-best lists.  Read as a gzip file if filename ends with .gz')
    parser.add_argument("ref_file", type=argparse.FileType('r'))
    parser.add_argument("output", type=argparse.FileType('w'))
    args = parser.parse_args()

    try:
        print('Reading n-best lists...')
        nbests = list(read_nbest_file(args.nbest_file))
        print('# of nbests: {}'.format(len(nbests)))
        print('Reading transcripts...')
        refs = read_transcript_table(args.ref_file)
        # The references are not passed to evaluate_nbests() below, so they
        # are installed in this module-level table beforehand.
        asr_tools.evaluation_util.REFERENCES = refs

        # This is the slow part.  Its return value is unused here;
        # presumably the call stores evals on the n-bests themselves, which
        # save_eval=True then persists (TODO confirm).
        print('Running evaluation...')
        evaluate_nbests(nbests)

        # Write them back out to a file
        write_nbests(args.output, nbests, save_eval=True)
    finally:
        # Originally only nbest_file was closed, and only on success.
        args.nbest_file.close()
        args.ref_file.close()
        args.output.close()
예제 #12
0
def set_global_references(ref_file):
    """Load ``ref_file`` as a transcript table into the global REFERENCES.

    The table is produced by ``read_transcript_table`` and replaces whatever
    the module-level ``REFERENCES`` name held before.
    """
    global REFERENCES
    transcript_table = read_transcript_table(ref_file)
    REFERENCES = transcript_table
예제 #13
0
 def test_read_transcript(self):
     """Make sure we can read a transcript and get the right number back."""
     with open(self.ref_file) as f:
         refs = read_transcript_table(f)
     # assertEqual shows the actual count on failure, unlike assertTrue(a == b).
     self.assertEqual(len(refs), 15)
예제 #14
0
 def test_read_transcript(self):
     """Check that the reference transcript table has the expected size."""
     with open(self.ref_file) as f:
         refs = read_transcript_table(f)
     # Compare with assertEqual so a mismatch reports the actual length.
     self.assertEqual(len(refs), 15)
예제 #15
0
def set_global_references(ref_file):
    """Read a reference file into the module-level REFERENCES table.

    ``ref_file`` is handed to ``read_transcript_table`` and the result is
    bound to the global name ``REFERENCES``.
    """
    global REFERENCES
    loaded = read_transcript_table(ref_file)
    REFERENCES = loaded
예제 #16
0
def main():
    """Evaluate the hypotheses against the references and print the result.

    Arguments come from ``arg_parser()``; its ``ref_file`` is read as the
    transcript table and its ``hyp_file`` as the hypotheses.
    """
    cli = arg_parser()
    references = read_transcript_table(cli.ref_file)
    hypotheses = read_transcript(cli.hyp_file)
    result = evaluate_hyps(hypotheses, references)
    print(result)