def test(ctx, model, evaluation_files, device, pad, threads, test_set):
    """
    Evaluate on a test set.
    """
    if not model:
        raise click.UsageError('No model to evaluate given.')

    import numpy as np
    from PIL import Image

    from kraken.serialization import render_report
    from kraken.lib import models
    from kraken.lib.dataset import global_align, compute_confusions, generate_input_transforms

    logger.info('Building test set from {} line images'.format(len(test_set) + len(evaluation_files)))

    nn = {}
    for p in model:
        message('Loading model {}\t'.format(p), nl=False)
        nn[p] = models.load_any(p)
        message('\u2713', fg='green')

    test_set = list(test_set)

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    next(iter(nn.values())).nn.set_num_threads(threads)

    # merge training_files into ground_truth list
    if evaluation_files:
        test_set.extend(evaluation_files)

    if len(test_set) == 0:
        raise click.UsageError('No evaluation data was provided to the test command. Use `-e` or the `test_set` argument.')

    def _get_text(im):
        with open(os.path.splitext(im)[0] + '.gt.txt', 'r') as fp:
            return get_display(fp.read())

    acc_list = []
    for p, net in nn.items():
        algn_gt: List[str] = []
        algn_pred: List[str] = []
        chars = 0
        error = 0
        message('Evaluating {}'.format(p))
        logger.info('Evaluating {}'.format(p))
        batch, channels, height, width = net.nn.input
        ts = generate_input_transforms(batch, height, width, channels, pad)
        with log.progressbar(test_set, label='Evaluating') as bar:
            for im_path in bar:
                i = ts(Image.open(im_path))
                text = _get_text(im_path)
                pred = net.predict_string(i)
                chars += len(text)
                c, algn1, algn2 = global_align(text, pred)
                algn_gt.extend(algn1)
                algn_pred.extend(algn2)
                error += c
        acc_list.append((chars - error) / chars)
        confusions, scripts, ins, dels, subs = compute_confusions(algn_gt, algn_pred)
        rep = render_report(p, chars, error, confusions, scripts, ins, dels, subs)
        logger.info(rep)
        message(rep)
    logger.info('Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(acc_list) * 100, np.std(acc_list) * 100))
    message('Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(acc_list) * 100, np.std(acc_list) * 100))
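
# Illustrative sketch (not part of kraken): how the per-model accuracy above is
# derived. In the command, `global_align` returns the edit cost `c` between the
# ground-truth line and the prediction, and accuracy is (chars - error) / chars.
# A plain Levenshtein distance stands in here for the alignment cost, so this is
# only an approximation of what compute_confusions/render_report work from; the
# helper name `_char_accuracy_sketch` is hypothetical.
def _char_accuracy_sketch(gt: str, pred: str) -> float:
    # classic dynamic-programming edit distance (insertions, deletions, substitutions)
    dist = [[0] * (len(pred) + 1) for _ in range(len(gt) + 1)]
    for i in range(len(gt) + 1):
        dist[i][0] = i
    for j in range(len(pred) + 1):
        dist[0][j] = j
    for i in range(1, len(gt) + 1):
        for j in range(1, len(pred) + 1):
            cost = 0 if gt[i - 1] == pred[j - 1] else 1
            dist[i][j] = min(dist[i - 1][j] + 1,         # deletion
                             dist[i][j - 1] + 1,         # insertion
                             dist[i - 1][j - 1] + cost)  # substitution
    error = dist[len(gt)][len(pred)]
    # same formula as acc_list.append((chars - error) / chars) in the command above
    return (len(gt) - error) / len(gt) if gt else 0.0

# example: _char_accuracy_sketch('kraken', 'kracken') -> 5/6 ≈ 0.833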