def test_two_hyps(self):
        boh = BagOfHypotheses()
        boh.add('abc', 0.5)
        boh.add('ac', 0.5)
        cn = produce_cn_from_boh(boh)

        self.assertEqual(cn, [{'a': 1.0}, {'b': 0.5, None: 0.5}, {'c': 1.0}])

    def test_lm_weight(self):
        boh = BagOfHypotheses()
        boh.add('abc', 0.0, 2.0)
        boh.add('ac', 1.0, -1.0)
        cn = produce_cn_from_boh(boh, lm_weight=2.0)

        first_prob = math.exp(0.0 + 2.0 * 2.0)
        second_prob = math.exp(1.0 + (-1.0) * 2.0)
        total_prob = first_prob + second_prob

        self.assertEqual(cn, [{
            'a': 1.0
        }, {
            'b': first_prob / total_prob,
            None: second_prob / total_prob
        }, {
            'c': 1.0
        }])

    def test_two_hyps_different_weight(self):
        boh = BagOfHypotheses()
        boh.add('abc', 0.0)
        boh.add('ac', 1.0)
        cn = produce_cn_from_boh(boh)

        first_prob = math.exp(0.0)
        second_prob = math.exp(1.0)
        total_prob = first_prob + second_prob

        self.assertEqual(cn, [{
            'a': 1.0
        }, {
            'b': first_prob / total_prob,
            None: second_prob / total_prob
        }, {
            'c': 1.0
        }])
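
The test methods above assume a unittest scaffold roughly like the following; the import paths and the class name are assumptions, since the snippets only show the names BagOfHypotheses and produce_cn_from_boh in use.

import math
import unittest

# Assumed import locations; only the bare names appear in the snippets above.
from decoding import BagOfHypotheses
from confusion_networks import produce_cn_from_boh


class ProduceCnFromBohTests(unittest.TestCase):
    # test_two_hyps, test_lm_weight, test_two_hyps_different_weight and
    # test_single_hypothesis (all shown in this section) go here as methods.
    pass


if __name__ == '__main__':
    unittest.main()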
Example #4
import pickle
import time

# decoders, confusion_networks, BLANK_SYMBOL, construct_lm, get_ocr_charset,
# prepare_dense_logits, Reporter and save_transcriptions are project-specific;
# their import paths are not shown in this snippet.


def main(args):
    print(args)

    ocr_engine_chars = get_ocr_charset(args.ocr_json)

    # Greedy decoding does not use a language model; otherwise build a CTC
    # prefix-search decoder, optionally rescored by an LM.
    if args.greedy:
        decoder = decoders.GreedyDecoder(ocr_engine_chars + [BLANK_SYMBOL])
    else:
        if args.lm:
            lm = construct_lm(args.lm)
        else:
            lm = None
        decoder = decoders.CTCPrefixLogRawNumpyDecoder(
            ocr_engine_chars + [BLANK_SYMBOL],
            k=args.beam_size,
            lm=lm,
            lm_scale=args.lm_scale,
            use_gpu=args.use_gpu,
            insertion_bonus=args.insertion_bonus,
        )

        if lm and args.eval:
            lm.eval()

    # The input pickle holds parallel sequences of line names and sparse logit matrices.
    with open(args.input, 'rb') as f:
        complete_input = pickle.load(f)
    names = complete_input['names']
    logits = complete_input['logits']

    decodings = {}
    confidences = {}
    if args.cn_best:
        cn_decodings = {}

    t_0 = time.time()
    reporter = Reporter(nop=not args.report_eta)
    for i, (name, sparse_logits) in enumerate(zip(names, logits)):
        time_per_line = (time.time() - t_0) / (i + 1)
        nb_lines_ahead = len(names) - (i + 1)
        reporter.report(
            'Processing {} [{}/{}, {:.2f}s/line, ETA {:.2f}s]'.format(
                name, i + 1, len(names), time_per_line,
                time_per_line * nb_lines_ahead))
        dense_logits = prepare_dense_logits(sparse_logits)

        if args.greedy:
            boh = decoder(dense_logits)
        else:
            boh = decoder(dense_logits, args.model_eos)
        one_best = boh.best_hyp()
        decodings[name] = one_best
        confidences[name] = boh.confidence()

        if args.cn_best:
            # Collapse the bag of hypotheses into a confusion network and keep its best path.
            cn = confusion_networks.produce_cn_from_boh(boh)
            cn_decodings[name] = confusion_networks.best_cn_path(cn)
    reporter.clear()

    save_transcriptions(args.best, decodings)

    with open(args.confidence, 'w') as f:
        for name in decodings:
            f.write('{} {:.3f}\n'.format(name, confidences[name]))

    if args.cn_best:
        save_transcriptions(args.cn_best, cn_decodings)
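
main() reads a fixed set of attributes from args; a minimal argument parser wiring them up could look like the sketch below. The flag spellings, types and defaults are assumptions inferred from the attribute accesses above, not the script's actual CLI.

import argparse


def parse_arguments():
    # Hypothetical flags mirroring the args.<attr> accesses in main(); defaults are guesses.
    parser = argparse.ArgumentParser()
    parser.add_argument('--ocr-json', required=True, help='OCR engine description with its character set')
    parser.add_argument('--input', required=True, help='pickle with "names" and "logits"')
    parser.add_argument('--best', required=True, help='output file for 1-best transcriptions')
    parser.add_argument('--confidence', required=True, help='output file for per-line confidences')
    parser.add_argument('--cn-best', help='optional output for confusion-network decodings')
    parser.add_argument('--greedy', action='store_true')
    parser.add_argument('--lm', help='language model to use for rescoring')
    parser.add_argument('--lm-scale', type=float, default=1.0)
    parser.add_argument('--beam-size', type=int, default=16)
    parser.add_argument('--insertion-bonus', type=float, default=0.0)
    parser.add_argument('--model-eos', action='store_true')
    parser.add_argument('--use-gpu', action='store_true')
    parser.add_argument('--eval', action='store_true', help='switch the LM to eval mode')
    parser.add_argument('--report-eta', action='store_true')
    return parser.parse_args()


if __name__ == '__main__':
    main(parse_arguments())
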
    def test_single_hypothesis(self):
        boh = BagOfHypotheses()
        boh.add('abc', 23.0, 2.0)
        cn = produce_cn_from_boh(boh)

        self.assertEqual(cn, [{'a': 1.0}, {'b': 1.0}, {'c': 1.0}])
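
As a closing usage sketch, the two confusion-network helpers can be combined the same way main() does in its cn_best branch; the import paths and the exact BagOfHypotheses.add() signature are assumptions based on the snippets above.

from decoding import BagOfHypotheses            # assumed module path
from confusion_networks import produce_cn_from_boh, best_cn_path

boh = BagOfHypotheses()
boh.add('abc', 0.0)   # transcription, log-score
boh.add('ac', 1.0)

cn = produce_cn_from_boh(boh)   # [{'a': 1.0}, {'b': ~0.27, None: ~0.73}, {'c': 1.0}]
print(best_cn_path(cn))         # presumably 'ac', since None wins the middle slot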