def main(args):
    print(args)

    ocr_engine_chars = get_ocr_charset(args.ocr_json)

    if args.greedy:
        decoder = decoders.GreedyDecoder(ocr_engine_chars + [BLANK_SYMBOL])
    else:
        if args.lm:
            lm = construct_lm(args.lm)
        else:
            lm = None

        decoder = decoders.CTCPrefixLogRawNumpyDecoder(
            ocr_engine_chars + [BLANK_SYMBOL],
            k=args.beam_size,
            lm=lm,
            lm_scale=args.lm_scale,
            use_gpu=args.use_gpu,
            insertion_bonus=args.insertion_bonus,
        )

        if lm and args.eval:
            lm.eval()

    with open(args.input, 'rb') as f:
        complete_input = pickle.load(f)
    names = complete_input['names']
    logits = complete_input['logits']

    decodings = {}
    confidences = {}
    if args.cn_best:
        cn_decodings = {}

    t_0 = time.time()
    reporter = Reporter(nop=not args.report_eta)
    for i, (name, sparse_logits) in enumerate(zip(names, logits)):
        time_per_line = (time.time() - t_0) / (i + 1)
        nb_lines_ahead = len(names) - (i + 1)
        reporter.report(
            'Processing {} [{}/{}, {:.2f}s/line, ETA {:.2f}s]'.format(
                name, i + 1, len(names), time_per_line, time_per_line * nb_lines_ahead))

        dense_logits = prepare_dense_logits(sparse_logits)

        if args.greedy:
            boh = decoder(dense_logits)
        else:
            boh = decoder(dense_logits, args.model_eos)

        one_best = boh.best_hyp()
        decodings[name] = one_best
        confidences[name] = boh.confidence()

        if args.cn_best:
            cn = confusion_networks.produce_cn_from_boh(boh)
            cn_decodings[name] = confusion_networks.best_cn_path(cn)

    reporter.clear()

    save_transcriptions(args.best, decodings)
    with open(args.confidence, 'w') as f:
        for name in decodings:
            f.write('{} {:.3f}\n'.format(name, confidences[name]))
    if args.cn_best:
        save_transcriptions(args.cn_best, cn_decodings)
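
# A minimal sketch of the argument parser this main() appears to expect.
# The flag names mirror the args.* attributes referenced above; the types,
# defaults, required markers, and help strings are assumptions, not the
# original parser definition.
import argparse


def build_argparser():
    parser = argparse.ArgumentParser(description='CTC decoding of pickled OCR logits.')
    parser.add_argument('--ocr-json', dest='ocr_json', required=True,
                        help='OCR engine description used to obtain the character set.')
    parser.add_argument('--input', required=True,
                        help='Pickle with "names" and "logits" entries.')
    parser.add_argument('--best', required=True,
                        help='Where to store one-best transcriptions.')
    parser.add_argument('--confidence', required=True,
                        help='Where to store per-line confidences.')
    parser.add_argument('--cn-best', dest='cn_best', default=None,
                        help='Optional output with confusion-network best paths.')
    parser.add_argument('--greedy', action='store_true',
                        help='Use greedy decoding instead of prefix beam search.')
    parser.add_argument('--lm', default=None, help='Language model to load.')
    parser.add_argument('--lm-scale', dest='lm_scale', type=float, default=1.0)
    parser.add_argument('--insertion-bonus', dest='insertion_bonus', type=float, default=0.0)
    parser.add_argument('--beam-size', dest='beam_size', type=int, default=16)
    parser.add_argument('--model-eos', dest='model_eos', action='store_true')
    parser.add_argument('--use-gpu', dest='use_gpu', action='store_true')
    parser.add_argument('--eval', action='store_true',
                        help='Put the LM into eval mode before decoding.')
    parser.add_argument('--report-eta', dest='report_eta', action='store_true')
    return parser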
def test_epsilon_removal(self):
    cn = [{'a': 1.0}, {'b': 0.3, None: 0.7}, {'c': 1.0}]
    self.assertEqual(best_cn_path(cn), 'ac')
def test_single_hyp_cn(self):
    cn = [{'a': 1.0}, {'b': 1.0}, {'c': 1.0}]
    self.assertEqual(best_cn_path(cn), 'abc')
def test_two_hyp_cn(self):
    cn = [{'a': 1.0}, {'b': 0.3, 'y': 0.7}, {'c': 1.0}]
    self.assertEqual(best_cn_path(cn), 'ayc')
def test_empty_cn(self):
    cn = []
    self.assertEqual(best_cn_path(cn), '')
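
# A minimal reference sketch consistent with the tests above: each
# confusion-network slot is a dict mapping symbols (or None for the
# epsilon / no-output arc) to posterior probabilities, and the best path
# keeps the highest-probability symbol of every slot while dropping the
# epsilon ones. The actual best_cn_path in confusion_networks may be
# implemented differently.
def best_cn_path_sketch(cn):
    symbols = []
    for slot in cn:
        best_symbol = max(slot, key=slot.get)
        if best_symbol is not None:  # None marks the epsilon (skip) arc
            symbols.append(best_symbol)
    return ''.join(symbols)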