def recognizer(input_image, model, pad, no_segmentation, bidi_reordering,
               script_ignore, mode, text_direction, segments) -> None:
    """
    Recognizes the lines of an already segmented image and prints the text.

    Args:
        input_image: Image handed unchanged to the ``rpred`` recognizers.
        model: Mapping of recognition models. The whole mapping is passed to
               ``rpred.mm_rpred`` when script detection is enabled, otherwise
               only ``model['default']`` is used.
        pad: Padding argument forwarded to the recognizer.
        no_segmentation: Not read by this function.
        bidi_reordering: Forwarded to the recognizer as ``bidi_reordering``.
        script_ignore: Forwarded to ``rpred.mm_rpred`` as ``script_ignore``.
        mode: Output format; only read by the disabled serialization branch.
        text_direction: Only read by the disabled serialization branch.
        segments: Segmentation dictionary; the keys ``script_detection`` and
                  ``boxes`` are read here.
    """
    bounds = segments

    # The set must be bound before the detection branch fills it. No binding
    # is visible in this function otherwise, so ``scripts.add`` would fail
    # with a NameError as soon as script detection is enabled.
    scripts = set()

    # Script detection.
    if bounds['script_detection']:
        for l in bounds['boxes']:
            for t in l:
                scripts.add(t[0])
        it = rpred.mm_rpred(model, input_image, bounds, pad,
                            bidi_reordering=bidi_reordering,
                            script_ignore=script_ignore)
    else:
        it = rpred.rpred(model['default'], input_image, bounds, pad,
                         bidi_reordering=bidi_reordering)

    preds = []
    with log.progressbar(it, label='Processing', length=len(bounds['boxes'])) as bar:
        for pred in bar:
            preds.append(pred)

    #--------------------
    print('Recognition results = {}.'.format('\n'.join(s.prediction for s in preds)))

    # NOTE(review): serialization is switched off with a constant ``False``.
    # The branch refers to ``output`` and ``base_image`` which are not bound
    # anywhere in this function, so it cannot simply be re-enabled.
    if False:
        with open_file(output, 'w', encoding='utf-8') as fp:
            print('Serializing as {} into {}'.format(mode, output))
            if mode != 'text':
                from kraken import serialization
                fp.write(
                    serialization.serialize(preds, base_image,
                                            Image.open(base_image).size,
                                            text_direction, scripts, mode))
            else:
                fp.write('\n'.join(s.prediction for s in preds))
def line_generator(ctx, font, maxlines, encoding, normalization, renormalize,
                   reorder, font_size, font_weight, language, max_length,
                   strip, disable_degradation, alpha, beta, distort,
                   distortion_sigma, legacy, output, text):
    """
    Generates artificial text line training data.

    Every unique line of the given text files is rendered to
    ``<output>/<idx>.png`` and its text written UTF-8 encoded to
    ``<output>/<idx>.gt.txt``.

    Args:
        ctx: Not read by this function.
        font: Passed to ``linegen.LineGenerator``.
        maxlines: When set and smaller than the number of unique lines, that
                  many distinct lines are sampled at random.
        encoding: Encoding used to open the text files.
        normalization: Unicode normalization form applied to every line
                       before de-duplication.
        renormalize: Unicode normalization form applied only to the text
                     handed to the renderer; the ground truth file keeps the
                     line as read.
        reorder: Writes ``get_display(line)`` instead of the raw line to the
                 ground truth file.
        font_size: Passed to ``linegen.LineGenerator``.
        font_weight: Passed to ``linegen.LineGenerator``.
        language: Passed to ``linegen.LineGenerator``.
        max_length: Keeps only lines shorter than this number of characters.
        strip: Strips leading and trailing whitespace from every line.
        disable_degradation: Skips ``degrade_line``/``distort_line``.
        alpha: Forwarded to ``linegen.degrade_line``.
        beta: Forwarded to ``linegen.degrade_line``.
        distort: Mean of the normal draw handed to ``linegen.distort_line``.
        distortion_sigma: Mean of the second normal draw handed to
                          ``linegen.distort_line``.
        legacy: Uses ``linegen.ocropy_degrade`` instead of the default
                degradation.
        output: Directory the image/text pairs are written to.
        text: Iterable of text files; nothing is done when it is empty.
    """
    import numpy as np

    from kraken import linegen
    from kraken.lib.util import make_printable

    lines: Set[str] = set()
    if not text:
        return
    with log.progressbar(text, label='Reading texts') as bar:
        # iterate over the bar itself, otherwise its position never advances
        for t in bar:
            with click.open_file(t, encoding=encoding) as fp:
                logger.info('Reading {}'.format(t))
                for l in fp:
                    lines.add(l.rstrip('\r\n'))
    if normalization:
        lines = {unicodedata.normalize(normalization, line) for line in lines}
    if strip:
        lines = {line.strip() for line in lines}
    if max_length:
        lines = {line for line in lines if len(line) < max_length}
    logger.info('Read {} lines'.format(len(lines)))
    message('Read {} unique lines'.format(len(lines)))
    if maxlines and maxlines < len(lines):
        message('Sampling {} lines\t'.format(maxlines), nl=False)
        llist = list(lines)
        # draw without replacement so exactly `maxlines` distinct lines
        # remain; drawing with replacement into a set silently loses lines
        # whenever an index repeats
        picked = np.random.choice(len(llist), size=maxlines, replace=False)
        lines = {llist[idx] for idx in picked}
        message('\u2713', fg='green')
    os.makedirs(output, exist_ok=True)

    # calculate the alphabet and print it for verification purposes
    alphabet: Set[str] = set()
    for line in lines:
        alphabet.update(line)
    chars = []
    combining = []
    for char in sorted(alphabet):
        k = make_printable(char)
        # characters whose printable form differs are listed separately
        if k != char:
            combining.append(k)
        else:
            chars.append(k)
    message('Σ (len: {})'.format(len(alphabet)))
    message('Symbols: {}'.format(''.join(chars)))
    if combining:
        message('Combining Characters: {}'.format(', '.join(combining)))

    lg = linegen.LineGenerator(font, font_size, font_weight, language)
    with log.progressbar(lines, label='Writing images') as bar:
        for idx, line in enumerate(bar):
            logger.info(line)
            try:
                if renormalize:
                    im = lg.render_line(unicodedata.normalize(renormalize, line))
                else:
                    im = lg.render_line(line)
            except KrakenCairoSurfaceException as e:
                # lines that cannot be rendered are skipped; their index is
                # simply left unused
                logger.info('{}: {} {}'.format(e.message, e.width, e.height))
                continue
            if not disable_degradation and not legacy:
                im = linegen.degrade_line(im, alpha=alpha, beta=beta)
                im = linegen.distort_line(im,
                                          abs(np.random.normal(distort)),
                                          abs(np.random.normal(distortion_sigma)))
            elif legacy:
                im = linegen.ocropy_degrade(im)
            im.save('{}/{:06d}.png'.format(output, idx))
            with open('{}/{:06d}.gt.txt'.format(output, idx), 'wb') as fp:
                if reorder:
                    fp.write(get_display(line).encode('utf-8'))
                else:
                    fp.write(line.encode('utf-8'))
def transcription(ctx, text_direction, scale, bw, maxcolseps, black_colseps,
                  font, font_style, prefill, pad, lines, output, images):
    """
    Creates transcription environments for ground truth generation.

    Args:
        ctx: Not read by this function.
        text_direction: Forwarded to ``pageseg.segment``.
        scale: Forwarded to ``pageseg.segment``.
        bw: Not read by this function; every page is binarized.
        maxcolseps: Forwarded to ``pageseg.segment``.
        black_colseps: Forwarded to ``pageseg.segment``.
        font: Passed to ``transcribe.TranscriptionInterface``.
        font_style: Passed to ``transcribe.TranscriptionInterface``.
        prefill: Optional model location. When given the model is loaded and
                 its predictions are added to each page as records.
        pad: Forwarded to ``pageseg.segment`` as ``pad``.
        lines: Optional JSON segmentation file used instead of running the
               segmenter. Incompatible with more than one image.
        output: Open file object the transcription environment is written to.
        images: Open image file objects, one per page.

    Raises:
        click.UsageError: If ``lines`` is combined with multiple images or
                          does not contain valid JSON.
    """
    from PIL import Image

    from kraken import rpred
    from kraken import pageseg
    from kraken import transcribe
    from kraken import binarization

    from kraken.lib import models

    ti = transcribe.TranscriptionInterface(font, font_style)

    if len(images) > 1 and lines:
        raise click.UsageError('--lines option is incompatible with multiple image files')

    if prefill:
        logger.info('Loading model {}'.format(prefill))
        message('Loading RNN', nl=False)
        prefill = models.load_any(prefill)
        message('\u2713', fg='green')

    with log.progressbar(images, label='Reading images') as bar:
        for fp in bar:
            logger.info('Reading {}'.format(fp.name))
            im = Image.open(fp)
            if im.mode not in ['1', 'L', 'P', 'RGB']:
                logger.warning('Input {} is in {} color mode. Converting to RGB'.format(fp.name, im.mode))
                im = im.convert('RGB')
            logger.info('Binarizing page')
            im_bin = binarization.nlbin(im)
            im_bin = im_bin.convert('1')
            logger.info('Segmenting page')
            if not lines:
                res = pageseg.segment(im_bin, text_direction, scale, maxcolseps, black_colseps, pad=pad)
            else:
                # a separate name keeps `fp` bound to the image file, so the
                # close() at the end of the iteration releases the image
                # instead of the already closed segmentation file
                with open_file(lines, 'r') as seg_fp:
                    try:
                        res = json.load(cast(IO[Any], seg_fp))
                    except ValueError as e:
                        raise click.UsageError('{} invalid segmentation: {}'.format(lines, str(e))) from e
            if prefill:
                it = rpred.rpred(prefill, im_bin, res)
                preds = []
                logger.info('Recognizing')
                for pred in it:
                    logger.debug('{}'.format(pred.prediction))
                    preds.append(pred)
                ti.add_page(im, res, records=preds)
            else:
                ti.add_page(im, res)
            fp.close()
    logger.info('Writing transcription to {}'.format(output.name))
    message('Writing output', nl=False)
    ti.write(output)
    message('\u2713', fg='green')
def extract(ctx, binarize, normalization, normalize_whitespace, reorder,
            rotate, output, format, transcriptions):
    """
    Extracts image-text pairs from a transcription environment created using
    ``ketos transcribe``.

    For every editable, non-blank ``li`` element the line image is cropped
    from the embedded page image and saved next to a UTF-8 encoded
    ``.gt.txt`` file. The names of all written images are collected in
    ``<output>/manifest.txt``.

    Args:
        ctx: Not read by this function.
        binarize: Runs ``binarization.nlbin`` on each page image before
                  cropping.
        normalization: Unicode normalization form applied to the line text.
        normalize_whitespace: Replaces every whitespace character by a space.
        reorder: Applies ``get_display`` to the line text.
        rotate: Rotates line images by 90 degrees when the text direction of
                the document starts with ``vertical``.
        output: Directory the pairs are written to; created when missing.
        format: Format string of the output file names; may reference
                ``idx``, ``src``, ``uuid`` and ``output``.
        transcriptions: Open transcription environment files.
    """
    import regex
    import base64

    from io import BytesIO
    from PIL import Image
    from lxml import html, etree

    from kraken import binarization

    try:
        os.mkdir(output)
    except OSError:
        # an already existing directory is fine, anything else is an error
        if not os.path.isdir(output):
            raise

    text_transforms = []
    if normalization:
        text_transforms.append(lambda x: unicodedata.normalize(normalization, x))
    if normalize_whitespace:
        text_transforms.append(lambda x: regex.sub(r'\s', ' ', x))
    if reorder:
        text_transforms.append(get_display)

    idx = 0
    manifest = []
    with log.progressbar(transcriptions, label='Reading transcriptions') as bar:
        for fp in bar:
            logger.info('Reading {}'.format(fp.name))
            doc = html.parse(fp)
            etree.strip_tags(doc, etree.Comment)
            td = doc.find(".//meta[@itemprop='text_direction']")
            if td is None:
                td = 'horizontal-lr'
            else:
                td = td.attrib['content']

            im = None
            dest_dict = {'output': output, 'idx': 0, 'src': fp.name, 'uuid': str(uuid.uuid4())}
            for section in doc.xpath('//section'):
                # the page image is embedded as a base64 data URI
                img = section.xpath('.//img')[0].get('src')
                fd = BytesIO(base64.b64decode(img.split(',')[1]))
                im = Image.open(fd)
                if not im:
                    logger.info('Skipping {} because image not found'.format(fp.name))
                    break
                if binarize:
                    im = binarization.nlbin(im)
                for line in section.iter('li'):
                    raw_text = u''.join(line.itertext())
                    # only editable lines carrying actual text are extracted
                    if not line.get('contenteditable') or not raw_text or raw_text.isspace():
                        continue
                    dest_dict['idx'] = idx
                    dest_dict['uuid'] = str(uuid.uuid4())
                    logger.debug('Writing line {:06d}'.format(idx))
                    l_img = im.crop([int(x) for x in line.get('data-bbox').split(',')])
                    if rotate and td.startswith('vertical'):
                        # rotate() returns a new image; rotate the cropped
                        # line and keep the result, the page stays untouched
                        # so later bounding boxes remain valid
                        l_img = l_img.rotate(90, expand=True)
                    l_img.save(('{output}/' + format + '.png').format(**dest_dict))
                    manifest.append((format + '.png').format(**dest_dict))
                    text = raw_text.strip()
                    for func in text_transforms:
                        text = func(text)
                    with open(('{output}/' + format + '.gt.txt').format(**dest_dict), 'wb') as t:
                        t.write(text.encode('utf-8'))
                    idx += 1
    logger.info('Extracted {} lines'.format(idx))
    with open('{}/manifest.txt'.format(output), 'w') as fp:
        fp.write('\n'.join(manifest))
def test(ctx, model, evaluation_files, device, pad, threads, test_set):
    """
    Evaluate on a test set.

    Each model is run over all line images; the ground truth of an image is
    read from the file sharing its name with the extension replaced by
    ``.gt.txt``. A report is printed per model, followed by mean and standard
    deviation of the per-model accuracies.

    Args:
        ctx: Not read by this function.
        model: Iterable of model locations to evaluate.
        evaluation_files: Additional line images appended to ``test_set``.
        device: Not read by this function.
        pad: Forwarded to ``generate_input_transforms``.
        threads: Thread count set on the first loaded model.
        test_set: Line images to evaluate on.

    Raises:
        click.UsageError: If no model or no evaluation data is given.
    """
    if not model:
        raise click.UsageError('No model to evaluate given.')

    import numpy as np
    from PIL import Image

    from kraken.serialization import render_report
    from kraken.lib import models
    from kraken.lib.dataset import global_align, compute_confusions, generate_input_transforms

    logger.info('Building test set from {} line images'.format(len(test_set) + len(evaluation_files)))

    nn = {}
    for p in model:
        message('Loading model {}\t'.format(p), nl=False)
        nn[p] = models.load_any(p)
        message('\u2713', fg='green')

    test_set = list(test_set)

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    next(iter(nn.values())).nn.set_num_threads(threads)

    # merge explicitly given evaluation files into the test set
    if evaluation_files:
        test_set.extend(evaluation_files)

    if len(test_set) == 0:
        raise click.UsageError('No evaluation data was provided to the test command. Use `-e` or the `test_set` argument.')

    def _get_text(im):
        # ground truth files are written UTF-8 encoded, so they must not be
        # decoded with whatever the locale default happens to be
        with open(os.path.splitext(im)[0] + '.gt.txt', 'r', encoding='utf-8') as fp:
            return get_display(fp.read())

    acc_list = []
    for p, net in nn.items():
        algn_gt: List[str] = []
        algn_pred: List[str] = []
        chars = 0
        error = 0
        message('Evaluating {}'.format(p))
        logger.info('Evaluating {}'.format(p))
        batch, channels, height, width = net.nn.input
        ts = generate_input_transforms(batch, height, width, channels, pad)
        with log.progressbar(test_set, label='Evaluating') as bar:
            for im_path in bar:
                i = ts(Image.open(im_path))
                text = _get_text(im_path)
                pred = net.predict_string(i)
                chars += len(text)
                c, algn1, algn2 = global_align(text, pred)
                algn_gt.extend(algn1)
                algn_pred.extend(algn2)
                error += c
        # NOTE(review): divides by zero when every ground truth file is empty
        acc_list.append((chars - error) / chars)
        confusions, scripts, ins, dels, subs = compute_confusions(algn_gt, algn_pred)
        rep = render_report(p, chars, error, confusions, scripts, ins, dels, subs)
        logger.info(rep)
        message(rep)

    summary = 'Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(acc_list) * 100, np.std(acc_list) * 100)
    logger.info(summary)
    message(summary)
def train(ctx, pad, output, spec, append, load, freq, quit, epochs, lag, min_delta, device, optimizer, lrate, momentum, weight_decay, schedule, partition, normalization, normalize_whitespace, codec, resize, reorder, training_files, evaluation_files, preload, threads, ground_truth):
    """
    Trains a model from image-text pairs.

    The function proceeds in the following order:

    1. validates the ``append``/``resize`` options against ``load`` and
       optionally loads an existing model,
    2. derives the input dimensions from the loaded model or from the VGSL
       ``spec`` and builds the image transforms,
    3. shuffles the ground truth, splits it into training and validation
       parts (or uses ``evaluation_files`` as validation set) and fills two
       ``GroundTruthDataset`` objects,
    4. encodes the training set and either appends layers to the loaded
       model, adapts the codec of the loaded model according to ``resize``
       or creates a new model from ``spec``,
    5. sets up data loader, optimizer, learning rate schedule and stopping
       criterion and runs a ``KrakenTrainer``,
    6. with ``quit == 'early'`` copies the best checkpoint to
       ``<output>_best.mlmodel``.

    Raises:
        click.BadOptionUsage: For ``append``/``resize`` without ``load``, a
                              malformed ``spec`` and unknown ``resize`` or
                              ``quit`` values.
        click.UsageError: If no training data is given.
    """
    if not load and append:
        raise click.BadOptionUsage('append', 'append option requires loading an existing model')

    if resize != 'fail' and not load:
        raise click.BadOptionUsage('resize', 'resize option requires loading an existing model')

    import re
    import torch
    import shutil
    import numpy as np

    from torch.utils.data import DataLoader

    # NOTE: inside this function the name `train` refers to the imported
    # kraken.lib.train module, not to this function.
    from kraken.lib import models, vgsl, train
    from kraken.lib.util import make_printable
    from kraken.lib.train import EarlyStopping, EpochStopping, TrainStopper, TrainScheduler, add_1cycle
    from kraken.lib.codec import PytorchCodec
    from kraken.lib.dataset import GroundTruthDataset, generate_input_transforms

    logger.info('Building ground truth set from {} line images'.format(len(ground_truth) + len(training_files)))

    # never changed below; only read when building the EpochStopping stopper
    completed_epochs = 0
    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise the output size is still unknown.
    nn = None
    #hyper_fields = ['freq', 'quit', 'epochs', 'lag', 'min_delta', 'optimizer', 'lrate', 'momentum', 'weight_decay', 'schedule', 'partition', 'normalization', 'normalize_whitespace', 'reorder', 'preload', 'completed_epochs', 'output']

    if load:
        logger.info('Loading existing model from {} '.format(load))
        message('Loading existing model from {}'.format(load), nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        #if nn.user_metadata and load_hyper_parameters:
        #    for param in hyper_fields:
        #        if param in nn.user_metadata:
        #            logger.info('Setting \'{}\' to \'{}\''.format(param, nn.user_metadata[param]))
        #            message('Setting \'{}\' to \'{}\''.format(param, nn.user_metadata[param]))
        #            locals()[param] = nn.user_metadata[param]
        message('\u2713', fg='green', nl=False)

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        if spec[0] != '[' or spec[-1] != ']':
            raise click.BadOptionUsage('spec', 'VGSL spec {} not bracketed'.format(spec))
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage('spec', 'Invalid input spec {}'.format(blocks[0]))
        # the spec lists the four numbers in a different order than nn.input
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch, height, width, channels, pad)
    except KrakenInputException as e:
        raise click.BadOptionUsage('spec', str(e))

    # disable automatic partition when given evaluation set explicitly
    if evaluation_files:
        partition = 1

    ground_truth = list(ground_truth)

    # merge training_files into ground_truth list
    if training_files:
        ground_truth.extend(training_files)

    if len(ground_truth) == 0:
        raise click.UsageError('No training data was provided to the train command. Use `-t` or the `ground_truth` argument.')

    np.random.shuffle(ground_truth)

    # preloading is only switched off here when it was not requested
    # explicitly (`preload` is falsy, i.e. None or False)
    if len(ground_truth) > 2500 and not preload:
        logger.info('Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter')
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    # the first `partition` share of the shuffled data is used for training
    tr_im = ground_truth[:int(len(ground_truth) * partition)]
    if evaluation_files:
        logger.debug('Using {} lines from explicit eval set'.format(len(evaluation_files)))
        te_im = evaluation_files
    else:
        te_im = ground_truth[int(len(ground_truth) * partition):]
        logger.debug('Taking {} lines from training for evaluation'.format(len(te_im)))

    # set multiprocessing tensor sharing strategy
    if 'file_system' in torch.multiprocessing.get_all_sharing_strategies():
        logger.debug('Setting multiprocessing tensor sharing strategy to file_system')
        torch.multiprocessing.set_sharing_strategy('file_system')

    gt_set = GroundTruthDataset(normalization=normalization, whitespace_normalization=normalize_whitespace, reorder=reorder, im_transforms=transforms, preload=preload)
    with log.progressbar(tr_im, label='Building training set') as bar:
        for im in bar:
            logger.debug('Adding line {} to training set'.format(im))
            # unreadable or invalid samples are logged and skipped
            try:
                gt_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    val_set = GroundTruthDataset(normalization=normalization, whitespace_normalization=normalize_whitespace, reorder=reorder, im_transforms=transforms, preload=preload)
    with log.progressbar(te_im, label='Building validation set') as bar:
        for im in bar:
            logger.debug('Adding line {} to validation set'.format(im))
            try:
                val_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    logger.info('Training set {} lines, validation set {} lines, alphabet {} symbols'.format(len(gt_set._images), len(val_set._images), len(gt_set.alphabet)))
    alpha_diff_only_train = set(gt_set.alphabet).difference(set(val_set.alphabet))
    alpha_diff_only_val = set(val_set.alphabet).difference(set(gt_set.alphabet))
    if alpha_diff_only_train:
        logger.warning('alphabet mismatch: chars in training set only: {} (not included in accuracy test during training)'.format(alpha_diff_only_train))
    if alpha_diff_only_val:
        logger.warning('alphabet mismatch: chars in validation set only: {} (not trained)'.format(alpha_diff_only_val))
    # log the alphabet ordered by descending count
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(), key=lambda x: x[1], reverse=True):
        char = make_printable(k)
        if char == k:
            char = '\t' + char
        logger.info(u'{}\t{}'.format(char, v))

    logger.debug('Encoding training set')

    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label()+1)
        logger.info('Appending {} to existing model {} after {}'.format(spec, nn.spec, append))
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info('Assembled model spec: {}'.format(nn.spec))
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException as e:
            message('Network codec not compatible with training set')
            # code points of the training set missing from the codec
            alpha_diff = set(gt_set.alphabet).difference(set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error('Training data and model codec alphabets mismatch: {}'.format(alpha_diff))
                ctx.exit(code=1)
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info('Resizing codec to include {} new code points'.format(len(alpha_diff)))
                # new code points get labels following the current maximum
                codec.c2l.update({k: [v] for v, k in enumerate(alpha_diff, start=codec.max_label()+1)})
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info('Resizing last layer in network to {} outputs'.format(codec.max_label()+1))
                nn.resize_output(codec.max_label()+1)
                gt_set.encode(nn.codec)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ', nl=False)
                logger.info('Resizing network or given codec to {} code sequences'.format(len(gt_set.alphabet)))
                # encode without a codec first, then merge the resulting
                # data set codec into the existing one
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info('Deleting {} output classes from network ({} retained)'.format(len(del_labels), len(codec)-len(del_labels)))
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label()+1, del_labels)
                message('\u2713', fg='green')
            else:
                raise click.BadOptionUsage('resize', 'Invalid resize value {}'.format(resize))
    else:
        gt_set.encode(codec)
        logger.info('Creating new model {} with {} outputs'.format(spec, gt_set.codec.max_label()+1))
        # append an output layer sized to the number of labels
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label()+1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    # half the number of data loading processes if device isn't cuda and we haven't enabled preloading
    if device == 'cpu' and not preload:
        loader_threads = threads // 2
    else:
        loader_threads = threads
    train_loader = DataLoader(gt_set, batch_size=1, shuffle=True, num_workers=loader_threads, pin_memory=True)
    # NOTE(review): in the else branch above this leaves `threads` at 0,
    # which is then handed to set_num_threads() below - confirm intended.
    threads -= loader_threads

    # don't encode validation set as the alphabets may not match causing encoding failures
    val_set.training_set = list(zip(val_set._images, val_set._gt))

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(optimizer, lrate, momentum))

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    nn.set_num_threads(threads)

    # only logged here; `device` is handed to the KrakenTrainer below
    logger.debug('Moving model to device {}'.format(device))

    # created with lr=0; learning rate and momentum are given to the
    # TrainScheduler phases below
    optim = getattr(torch.optim, optimizer)(nn.nn.parameters(), lr=0)

    if 'accuracy' not in nn.user_metadata:
        nn.user_metadata['accuracy'] = []

    tr_it = TrainScheduler(optim)
    if schedule == '1cycle':
        add_1cycle(tr_it, int(len(gt_set) * epochs), lrate, momentum, momentum - 0.10, weight_decay)
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1, (lrate, lrate), (momentum, momentum), weight_decay, train.annealing_const)

    if quit == 'early':
        st_it = EarlyStopping(min_delta, lag)
    elif quit == 'dumb':
        st_it = EpochStopping(epochs - completed_epochs)
    else:
        raise click.BadOptionUsage('quit', 'Invalid training interruption scheme {}'.format(quit))

    #for param in hyper_fields:
    #    logger.debug('Setting \'{}\' to \'{}\' in model metadata'.format(param, locals()[param]))
    #    nn.user_metadata[param] = locals()[param]

    trainer = train.KrakenTrainer(model=nn, optimizer=optim, device=device, filename_prefix=output, event_frequency=freq, train_set=train_loader, val_set=val_set, stopper=st_it)

    trainer.add_lr_scheduler(tr_it)

    with log.progressbar(label='stage {}/{}'.format(1, trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞'), length=trainer.event_it, show_pos=True) as bar:

        # callbacks handed to trainer.run(); both close over `bar`
        def _draw_progressbar():
            bar.update(1)

        def _print_eval(epoch, accuracy, chars, error):
            message('Accuracy report ({}) {:0.4f} {} {}'.format(epoch, accuracy, chars, error))
            # reset progress bar
            bar.label = 'stage {}/{}'.format(epoch+1, trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞')
            bar.pos = 0
            bar.finished = False

        trainer.run(_print_eval, _draw_progressbar)

    if quit == 'early':
        # the messages say 'Moving' but the checkpoint is copied, the
        # per-epoch file stays in place
        message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format(output, trainer.stopper.best_epoch, trainer.stopper.best_loss))
        logger.info('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format(output, trainer.stopper.best_epoch, trainer.stopper.best_loss))
        shutil.copy('{}_{}.mlmodel'.format(output, trainer.stopper.best_epoch), '{}_best.mlmodel'.format(output))
def process_pipeline(subcommands, input, batch_input, suffix, verbose, format_type, pdf_format, **args):
    """
    Helper function calling the partials returned by each subcommand and
    placing their respective outputs in temporary files.

    Args:
        subcommands: Callables accepting ``input`` and ``output`` keyword
                     arguments, run in order for every input/output pair.
        input: Iterable of ``(input_path, output_path)`` pairs.
        batch_input: Glob expressions expanded into additional pairs; only
                     evaluated when ``suffix`` is set as well.
        suffix: Appended to the extension-less input path to form the output
                path of batch and PDF inputs.
        verbose: Not read by this function.
        format_type: With ``'pdf'`` every input is split into one temporary
                     PNG per page before processing.
        pdf_format: Format string for the output names of PDF pages; may
                    reference ``idx``, ``src`` and ``uuid``.
        **args: Not read by this function.
    """
    import glob
    import uuid
    import tempfile

    input = list(input)
    # expand batch inputs
    if batch_input and suffix:
        for batch_expr in batch_input:
            for in_file in glob.glob(batch_expr, recursive=True):
                input.append((in_file, '{}{}'.format(os.path.splitext(in_file)[0], suffix)))

    # parse pdfs
    if format_type == 'pdf':
        import pyvips

        if not batch_input:
            logger.warning('PDF inputs not added with batch option. Manual output filename will be ignored and `-o` utilized.')
        new_input = []
        num_pages = 0
        # first pass: count pages for the progress bar
        for fpath, _ in input:
            try:
                doc = pyvips.Image.new_from_file(fpath, dpi=300, n=-1, access="sequential")
            except pyvips.error.Error:
                # unreadable inputs are reported and skipped by the
                # extraction loop below instead of aborting here
                continue
            if 'n-pages' in doc.get_fields():
                num_pages += doc.get('n-pages')

        with log.progressbar(length=num_pages, label='Extracting PDF pages') as bar:
            for fpath, _ in input:
                try:
                    doc = pyvips.Image.new_from_file(fpath, dpi=300, n=-1, access="sequential")
                    if 'n-pages' not in doc.get_fields():
                        logger.warning(f'{fpath} does not contain pages. Skipping.')
                        continue
                    n_pages = doc.get('n-pages')

                    for i in range(n_pages):
                        dest_dict = {'idx': i, 'src': fpath, 'uuid': str(uuid.uuid4())}
                        fd, filename = tempfile.mkstemp(suffix='.png')
                        os.close(fd)
                        page = pyvips.Image.new_from_file(fpath, dpi=300, page=i, access="sequential")
                        logger.info(f'Saving temporary image {fpath}:{dest_dict["idx"]} to {filename}')
                        page.write_to_file(filename)
                        new_input.append((filename, pdf_format.format(**dest_dict) + suffix))
                        bar.update(1)
                except pyvips.error.Error:
                    logger.warning(f'{fpath} is not a PDF file. Skipping.')
        input = new_input

    ctx = click.get_current_context()

    for io_pair in input:
        ctx.meta['first_process'] = True
        ctx.meta['last_process'] = False
        ctx.meta['orig_file'] = io_pair[0]
        if 'base_image' in ctx.meta:
            del ctx.meta['base_image']
        # tracked outside the try block so that cleanup never depends on
        # names bound only after all temporary files were created
        tmp_paths = []
        try:
            # one intermediate file between each pair of adjacent subcommands
            for _ in subcommands[1:]:
                fd, tmp_path = tempfile.mkstemp()
                os.close(fd)
                tmp_paths.append(tmp_path)
            fc = [io_pair[0]] + tmp_paths + [io_pair[1]]
            for idx, (task, task_input, task_output) in enumerate(zip(subcommands, fc, fc[1:])):
                if len(fc) - 2 == idx:
                    ctx.meta['last_process'] = True
                task(input=task_input, output=task_output)
        except Exception as e:
            logger.error(f'Failed processing {io_pair[0]}: {str(e)}')
            if ctx.meta['raise_failed'] is True:
                raise
        finally:
            for f in tmp_paths:
                os.unlink(f)
            # clean up temporary PDF image files
            if format_type == 'pdf':
                logger.debug(f'unlinking {io_pair[0]}')
                os.unlink(io_pair[0])
def recognizer(model, pad, no_segmentation, bidi_reordering, script_ignore, input, output) -> None:
    """
    Recognizes the text of one pipeline input and writes it to ``output``.

    The segmentation is taken, in this order, from a parsed non-image input
    document (first pipeline step only), from ``input`` read as JSON when it
    is not the base image itself, or from a single box spanning the whole
    image when ``no_segmentation`` is set.

    Args:
        model: Mapping of recognition models. The whole mapping is passed to
               ``rpred.mm_rpred`` when the segmentation enables script
               detection, otherwise only ``model['default']`` is used.
        pad: Padding argument forwarded to the recognizer.
        no_segmentation: Treat the whole image as one line when no
                         segmentation is available.
        bidi_reordering: Forwarded to the recognizer as ``bidi_reordering``.
        script_ignore: Forwarded to ``rpred.mm_rpred`` as ``script_ignore``.
        input: Path of the input of this pipeline step.
        output: Path the recognition result is written to.

    Raises:
        click.BadParameter: If the base image cannot be opened.
        click.UsageError: If the segmentation is invalid JSON or missing.
    """
    import json

    from kraken import rpred

    ctx = click.get_current_context()

    bounds = None
    if 'base_image' not in ctx.meta:
        ctx.meta['base_image'] = input

    if ctx.meta['first_process']:
        if ctx.meta['input_format_type'] != 'image':
            # structured input: the parser yields both image and segmentation
            doc = get_input_parser(ctx.meta['input_format_type'])(input)
            ctx.meta['base_image'] = doc['image']
            doc['text_direction'] = 'horizontal-lr'
            bounds = doc
    try:
        im = Image.open(ctx.meta['base_image'])
    except IOError as e:
        raise click.BadParameter(str(e)) from e

    if not bounds and ctx.meta['base_image'] != input:
        # input is the output of a previous step, i.e. a segmentation file
        with open_file(input, 'r') as fp:
            try:
                fp = cast(IO[Any], fp)
                bounds = json.load(fp)
            except ValueError as e:
                raise click.UsageError(f'{input} invalid segmentation: {str(e)}') from e
    elif not bounds:
        if no_segmentation:
            bounds = {
                'script_detection': False,
                'text_direction': 'horizontal-lr',
                'boxes': [(0, 0) + im.size]
            }
        else:
            raise click.UsageError('No line segmentation given. Add one with the input or run `segment` first.')
    elif no_segmentation:
        logger.warning('no_segmentation mode enabled but segmentation defined. Ignoring --no-segmentation option.')

    # handed to the serializer as is; nothing is added to it here
    scripts = set()
    # script detection
    if 'script_detection' in bounds and bounds['script_detection']:
        it = rpred.mm_rpred(model, im, bounds, pad,
                            bidi_reordering=bidi_reordering,
                            script_ignore=script_ignore)
    else:
        it = rpred.rpred(model['default'], im, bounds, pad,
                         bidi_reordering=bidi_reordering)

    preds = []

    with log.progressbar(it, label='Processing') as bar:
        for pred in bar:
            preds.append(pred)

    with open_file(output, 'w', encoding='utf-8') as fp:
        fp = cast(IO[Any], fp)
        message(f'Writing recognition results for {ctx.meta["orig_file"]}\t', nl=False)
        logger.info('Serializing as {} into {}'.format(ctx.meta['output_mode'], output))
        if ctx.meta['output_mode'] != 'native':
            from kraken import serialization
            # `im` was opened from ctx.meta['base_image'] above, so its size
            # is reused instead of opening the file a second time
            fp.write(
                serialization.serialize(
                    preds, ctx.meta['base_image'], im.size,
                    ctx.meta['text_direction'], scripts,
                    bounds['regions'] if 'regions' in bounds else None,
                    ctx.meta['output_mode']))
        else:
            fp.write('\n'.join(s.prediction for s in preds))
        message('\u2713', fg='green')
def transcription(ctx, text_direction, scale, bw, maxcolseps, black_colseps,
                  font, font_style, prefill, pad, lines, output, images):
    """
    Creates transcription environments for ground truth generation.

    Args:
        ctx: Not read by this function.
        text_direction: Forwarded to ``pageseg.segment``.
        scale: Forwarded to ``pageseg.segment``.
        bw: Not read by this function; every page is binarized.
        maxcolseps: Forwarded to ``pageseg.segment``.
        black_colseps: Forwarded to ``pageseg.segment``.
        font: Passed to ``transcribe.TranscriptionInterface``.
        font_style: Passed to ``transcribe.TranscriptionInterface``.
        prefill: Optional model location. When given the model is loaded and
                 its predictions are added to each page as records.
        pad: Forwarded to ``pageseg.segment`` as ``pad``.
        lines: Optional JSON segmentation file used instead of running the
               segmenter. Incompatible with more than one image.
        output: Open file object the transcription environment is written to.
        images: Open image file objects, one per page.

    Raises:
        click.UsageError: If ``lines`` is combined with multiple images or
                          does not contain valid JSON.
    """
    from PIL import Image

    from kraken import rpred
    from kraken import pageseg
    from kraken import transcribe
    from kraken import binarization

    from kraken.lib import models

    ti = transcribe.TranscriptionInterface(font, font_style)

    if len(images) > 1 and lines:
        raise click.UsageError(
            '--lines option is incompatible with multiple image files')

    if prefill:
        logger.info('Loading model {}'.format(prefill))
        message('Loading RNN', nl=False)
        prefill = models.load_any(prefill)
        message('\u2713', fg='green')

    with log.progressbar(images, label='Reading images') as bar:
        for fp in bar:
            logger.info('Reading {}'.format(fp.name))
            im = Image.open(fp)
            if im.mode not in ['1', 'L', 'P', 'RGB']:
                logger.warning(
                    'Input {} is in {} color mode. Converting to RGB'.format(
                        fp.name, im.mode))
                im = im.convert('RGB')
            logger.info('Binarizing page')
            im_bin = binarization.nlbin(im)
            im_bin = im_bin.convert('1')
            logger.info('Segmenting page')
            if not lines:
                res = pageseg.segment(im_bin,
                                      text_direction,
                                      scale,
                                      maxcolseps,
                                      black_colseps,
                                      pad=pad)
            else:
                # a separate name keeps `fp` bound to the image file, so the
                # close() at the end of the iteration releases the image
                # instead of the already closed segmentation file
                with open_file(lines, 'r') as seg_fp:
                    try:
                        res = json.load(cast(IO[Any], seg_fp))
                    except ValueError as e:
                        raise click.UsageError(
                            '{} invalid segmentation: {}'.format(
                                lines, str(e))) from e
            if prefill:
                it = rpred.rpred(prefill, im_bin, res)
                preds = []
                logger.info('Recognizing')
                for pred in it:
                    logger.debug('{}'.format(pred.prediction))
                    preds.append(pred)
                ti.add_page(im, res, records=preds)
            else:
                ti.add_page(im, res)
            fp.close()
    logger.info('Writing transcription to {}'.format(output.name))
    message('Writing output', nl=False)
    ti.write(output)
    message('\u2713', fg='green')
def extract(ctx, binarize, normalization, normalize_whitespace, reorder,
            rotate, output, format, transcriptions):
    """
    Extracts image-text pairs from a transcription environment created using
    ``ketos transcribe``.

    For every editable, non-blank ``li`` element the line image is cropped
    from the embedded page image and saved next to a UTF-8 encoded
    ``.gt.txt`` file. The names of all written images are collected in
    ``<output>/manifest.txt``.

    Args:
        ctx: Not read by this function.
        binarize: Runs ``binarization.nlbin`` on each page image before
                  cropping.
        normalization: Unicode normalization form applied to the line text.
        normalize_whitespace: Replaces every whitespace character by a space.
        reorder: Applies ``get_display`` to the line text.
        rotate: Rotates line images by 90 degrees when the text direction of
                the document starts with ``vertical``.
        output: Directory the pairs are written to; created when missing.
        format: Format string of the output file names; may reference
                ``idx``, ``src``, ``uuid`` and ``output``.
        transcriptions: Open transcription environment files.
    """
    import regex
    import base64

    from io import BytesIO
    from PIL import Image
    from lxml import html, etree

    from kraken import binarization

    try:
        os.mkdir(output)
    except OSError:
        # an already existing directory is fine, anything else is an error
        if not os.path.isdir(output):
            raise

    text_transforms = []
    if normalization:
        text_transforms.append(
            lambda x: unicodedata.normalize(normalization, x))
    if normalize_whitespace:
        text_transforms.append(lambda x: regex.sub(r'\s', ' ', x))
    if reorder:
        text_transforms.append(get_display)

    idx = 0
    manifest = []
    with log.progressbar(transcriptions,
                         label='Reading transcriptions') as bar:
        for fp in bar:
            logger.info('Reading {}'.format(fp.name))
            doc = html.parse(fp)
            etree.strip_tags(doc, etree.Comment)
            td = doc.find(".//meta[@itemprop='text_direction']")
            if td is None:
                td = 'horizontal-lr'
            else:
                td = td.attrib['content']

            im = None
            dest_dict = {
                'output': output,
                'idx': 0,
                'src': fp.name,
                'uuid': str(uuid.uuid4())
            }
            for section in doc.xpath('//section'):
                # the page image is embedded as a base64 data URI
                img = section.xpath('.//img')[0].get('src')
                fd = BytesIO(base64.b64decode(img.split(',')[1]))
                im = Image.open(fd)
                if not im:
                    logger.info('Skipping {} because image not found'.format(
                        fp.name))
                    break
                if binarize:
                    im = binarization.nlbin(im)
                for line in section.iter('li'):
                    raw_text = u''.join(line.itertext())
                    # only editable lines carrying actual text are extracted
                    if (not line.get('contenteditable') or not raw_text
                            or raw_text.isspace()):
                        continue
                    dest_dict['idx'] = idx
                    dest_dict['uuid'] = str(uuid.uuid4())
                    logger.debug('Writing line {:06d}'.format(idx))
                    l_img = im.crop(
                        [int(x) for x in line.get('data-bbox').split(',')])
                    if rotate and td.startswith('vertical'):
                        # rotate() returns a new image; rotate the cropped
                        # line and keep the result, the page stays untouched
                        # so later bounding boxes remain valid
                        l_img = l_img.rotate(90, expand=True)
                    l_img.save(
                        ('{output}/' + format + '.png').format(**dest_dict))
                    manifest.append((format + '.png').format(**dest_dict))
                    text = raw_text.strip()
                    for func in text_transforms:
                        text = func(text)
                    with open(('{output}/' + format +
                               '.gt.txt').format(**dest_dict), 'wb') as t:
                        t.write(text.encode('utf-8'))
                    idx += 1
    logger.info('Extracted {} lines'.format(idx))
    with open('{}/manifest.txt'.format(output), 'w') as fp:
        fp.write('\n'.join(manifest))
def test(ctx, model, evaluation_files, device, pad, threads, test_set):
    """
    Evaluate on a test set.

    Each model is run over all line images; the ground truth of an image is
    read from the file sharing its name with the extension replaced by
    ``.gt.txt``. A report is printed per model, followed by mean and standard
    deviation of the per-model accuracies.

    Args:
        ctx: Not read by this function.
        model: Iterable of model locations to evaluate.
        evaluation_files: Additional line images appended to ``test_set``.
        device: Not read by this function.
        pad: Forwarded to ``generate_input_transforms``.
        threads: Thread count set on the first loaded model.
        test_set: Line images to evaluate on.

    Raises:
        click.UsageError: If no model or no evaluation data is given.
    """
    if not model:
        raise click.UsageError('No model to evaluate given.')

    import numpy as np
    from PIL import Image

    from kraken.serialization import render_report
    from kraken.lib import models
    from kraken.lib.dataset import global_align, compute_confusions, generate_input_transforms

    logger.info('Building test set from {} line images'.format(
        len(test_set) + len(evaluation_files)))

    nn = {}
    for p in model:
        message('Loading model {}\t'.format(p), nl=False)
        nn[p] = models.load_any(p)
        message('\u2713', fg='green')

    test_set = list(test_set)

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    next(iter(nn.values())).nn.set_num_threads(threads)

    # merge explicitly given evaluation files into the test set
    if evaluation_files:
        test_set.extend(evaluation_files)

    if len(test_set) == 0:
        raise click.UsageError(
            'No evaluation data was provided to the test command. Use `-e` or the `test_set` argument.'
        )

    def _get_text(im):
        # ground truth files are written UTF-8 encoded, so they must not be
        # decoded with whatever the locale default happens to be
        with open(os.path.splitext(im)[0] + '.gt.txt', 'r',
                  encoding='utf-8') as fp:
            return get_display(fp.read())

    acc_list = []
    for p, net in nn.items():
        algn_gt: List[str] = []
        algn_pred: List[str] = []
        chars = 0
        error = 0
        message('Evaluating {}'.format(p))
        logger.info('Evaluating {}'.format(p))
        batch, channels, height, width = net.nn.input
        ts = generate_input_transforms(batch, height, width, channels, pad)
        with log.progressbar(test_set, label='Evaluating') as bar:
            for im_path in bar:
                i = ts(Image.open(im_path))
                text = _get_text(im_path)
                pred = net.predict_string(i)
                chars += len(text)
                c, algn1, algn2 = global_align(text, pred)
                algn_gt.extend(algn1)
                algn_pred.extend(algn2)
                error += c
        # NOTE(review): divides by zero when every ground truth file is empty
        acc_list.append((chars - error) / chars)
        confusions, scripts, ins, dels, subs = compute_confusions(
            algn_gt, algn_pred)
        rep = render_report(p, chars, error, confusions, scripts, ins, dels,
                            subs)
        logger.info(rep)
        message(rep)

    summary = 'Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(
        np.mean(acc_list) * 100, np.std(acc_list) * 100)
    logger.info(summary)
    message(summary)
def train(ctx, pad, output, spec, append, load, freq, quit, epochs, lag,
          min_delta, device, optimizer, lrate, momentum, weight_decay,
          schedule, partition, normalization, normalize_whitespace, codec,
          resize, reorder, training_files, evaluation_files, preload, threads,
          ground_truth):
    """
    Trains a model from image-text pairs.
    """
    # Outline:
    #   1. validate option combinations and optionally load an existing model
    #   2. build training/validation GroundTruthDatasets from the line images
    #   3. create, append to, or resize the network so its codec fits the data
    #   4. hand everything to KrakenTrainer and run until the stopper fires
    #
    # Raises click.BadOptionUsage / click.UsageError on invalid options or
    # missing training data; calls ctx.exit(1) when the model codec does not
    # cover the training alphabet and resize == 'fail'.
    if not load and append:
        raise click.BadOptionUsage(
            'append', 'append option requires loading an existing model')

    if resize != 'fail' and not load:
        raise click.BadOptionUsage(
            'resize', 'resize option requires loading an existing model')

    import re
    import torch
    import shutil
    import numpy as np

    from torch.utils.data import DataLoader

    from kraken.lib import models, vgsl, train
    from kraken.lib.util import make_printable
    from kraken.lib.train import EarlyStopping, EpochStopping, TrainStopper, TrainScheduler, add_1cycle
    from kraken.lib.codec import PytorchCodec
    from kraken.lib.dataset import GroundTruthDataset, generate_input_transforms

    logger.info('Building ground truth set from {} line images'.format(
        len(ground_truth) + len(training_files)))

    completed_epochs = 0
    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise the output size is still unknown.
    nn = None
    if load:
        logger.info('Loading existing model from {} '.format(load))
        message('Loading existing model from {}'.format(load), nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        message('\u2713', fg='green', nl=False)

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        # startswith/endswith instead of indexing so that an empty spec is
        # reported as a usage error rather than an IndexError
        if not spec.startswith('[') or not spec.endswith(']'):
            raise click.BadOptionUsage(
                'spec', 'VGSL spec {} not bracketed'.format(spec))
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage(
                'spec', 'Invalid input spec {}'.format(blocks[0]))
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch, height, width, channels,
                                               pad)
    except KrakenInputException as e:
        raise click.BadOptionUsage('spec', str(e))

    # disable automatic partition when given evaluation set explicitly
    if evaluation_files:
        partition = 1

    ground_truth = list(ground_truth)

    # merge training_files into ground_truth list
    if training_files:
        ground_truth.extend(training_files)

    if len(ground_truth) == 0:
        raise click.UsageError(
            'No training data was provided to the train command. Use `-t` or the `ground_truth` argument.'
        )

    np.random.shuffle(ground_truth)

    if len(ground_truth) > 2500 and not preload:
        logger.info(
            'Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter'
        )
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    tr_im = ground_truth[:int(len(ground_truth) * partition)]
    if evaluation_files:
        logger.debug('Using {} lines from explicit eval set'.format(
            len(evaluation_files)))
        te_im = evaluation_files
    else:
        te_im = ground_truth[int(len(ground_truth) * partition):]
        logger.debug('Taking {} lines from training for evaluation'.format(
            len(te_im)))

    # set multiprocessing tensor sharing strategy
    if 'file_system' in torch.multiprocessing.get_all_sharing_strategies():
        logger.debug(
            'Setting multiprocessing tensor sharing strategy to file_system')
        torch.multiprocessing.set_sharing_strategy('file_system')

    gt_set = GroundTruthDataset(normalization=normalization,
                                whitespace_normalization=normalize_whitespace,
                                reorder=reorder,
                                im_transforms=transforms,
                                preload=preload)
    with log.progressbar(tr_im, label='Building training set') as bar:
        for im in bar:
            logger.debug('Adding line {} to training set'.format(im))
            try:
                gt_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    val_set = GroundTruthDataset(normalization=normalization,
                                 whitespace_normalization=normalize_whitespace,
                                 reorder=reorder,
                                 im_transforms=transforms,
                                 preload=preload)
    with log.progressbar(te_im, label='Building validation set') as bar:
        for im in bar:
            logger.debug('Adding line {} to validation set'.format(im))
            try:
                val_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    logger.info(
        'Training set {} lines, validation set {} lines, alphabet {} symbols'.
        format(len(gt_set._images), len(val_set._images),
               len(gt_set.alphabet)))
    alpha_diff = set(gt_set.alphabet).symmetric_difference(
        set(val_set.alphabet))
    if alpha_diff:
        # Logger.warn is a deprecated alias of Logger.warning
        logger.warning('alphabet mismatch {}'.format(alpha_diff))
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(),
                       key=lambda x: x[1],
                       reverse=True):
        char = make_printable(k)
        if char == k:
            char = '\t' + char
        logger.info(u'{}\t{}'.format(char, v))

    logger.debug('Encoding training set')

    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        logger.info('Appending {} to existing model {} after {}'.format(
            spec, nn.spec, append))
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info('Assembled model spec: {}'.format(nn.spec))
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException:
            message('Network codec not compatible with training set')
            alpha_diff = set(gt_set.alphabet).difference(
                set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error(
                    'Training data and model codec alphabets mismatch: {}'.
                    format(alpha_diff))
                ctx.exit(code=1)
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info(
                    'Resizing codec to include {} new code points'.format(
                        len(alpha_diff)))
                codec.c2l.update({
                    k: [v]
                    for v, k in enumerate(alpha_diff,
                                          start=codec.max_label() + 1)
                })
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info(
                    'Resizing last layer in network to {} outputs'.format(
                        codec.max_label() + 1))
                nn.resize_output(codec.max_label() + 1)
                # the first encode attempt failed, redo it with the extended
                # codec
                gt_set.encode(nn.codec)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ', nl=False)
                logger.info(
                    'Resizing network or given codec to {} code sequences'.
                    format(len(gt_set.alphabet)))
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info(
                    'Deleting {} output classes from network ({} retained)'.
                    format(len(del_labels),
                           len(codec) - len(del_labels)))
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label() + 1, del_labels)
                message('\u2713', fg='green')
            else:
                raise click.BadOptionUsage(
                    'resize', 'Invalid resize value {}'.format(resize))
    else:
        gt_set.encode(codec)
        logger.info('Creating new model {} with {} outputs'.format(
            spec,
            gt_set.codec.max_label() + 1))
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    # half the number of data loading processes if device isn't cuda and we
    # haven't enabled preloading
    if device == 'cpu' and not preload:
        loader_threads = threads // 2
    else:
        loader_threads = threads
    train_loader = DataLoader(gt_set,
                              batch_size=1,
                              shuffle=True,
                              num_workers=loader_threads,
                              pin_memory=True)
    # When the loader takes every thread the remainder is 0; keep at least one
    # thread for the model so set_num_threads() below never receives 0.
    threads = max(threads - loader_threads, 1)

    # don't encode validation set as the alphabets may not match causing
    # encoding failures
    val_set.training_set = list(zip(val_set._images, val_set._gt))

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(
        optimizer, lrate, momentum))

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    nn.set_num_threads(threads)

    logger.debug('Moving model to device {}'.format(device))
    # lr=0 is a placeholder; the scheduler phases below carry lrate/momentum
    optim = getattr(torch.optim, optimizer)(nn.nn.parameters(), lr=0)

    if 'accuracy' not in nn.user_metadata:
        nn.user_metadata['accuracy'] = []

    tr_it = TrainScheduler(optim)
    if schedule == '1cycle':
        add_1cycle(tr_it, int(len(gt_set) * epochs), lrate, momentum,
                   momentum - 0.10, weight_decay)
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1, (lrate, lrate), (momentum, momentum), weight_decay,
                        train.annealing_const)

    if quit == 'early':
        st_it = EarlyStopping(min_delta, lag)
    elif quit == 'dumb':
        st_it = EpochStopping(epochs - completed_epochs)
    else:
        raise click.BadOptionUsage(
            'quit', 'Invalid training interruption scheme {}'.format(quit))

    trainer = train.KrakenTrainer(model=nn,
                                  optimizer=optim,
                                  device=device,
                                  filename_prefix=output,
                                  event_frequency=freq,
                                  train_set=train_loader,
                                  val_set=val_set,
                                  stopper=st_it)

    trainer.add_lr_scheduler(tr_it)

    with log.progressbar(label='stage {}/{}'.format(
            1, trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞'),
                         length=trainer.event_it,
                         show_pos=True) as bar:

        def _draw_progressbar():
            bar.update(1)

        def _print_eval(epoch, accuracy, chars, error):
            message('Accuracy report ({}) {:0.4f} {} {}'.format(
                epoch, accuracy, chars, error))
            # reset progress bar
            bar.label = 'stage {}/{}'.format(
                epoch + 1,
                trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞')
            bar.pos = 0
            bar.finished = False

        trainer.run(_print_eval, _draw_progressbar)

    if quit == 'early':
        # the best checkpoint is copied (not removed) to <output>_best.mlmodel
        message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
                format(output, trainer.stopper.best_epoch,
                       trainer.stopper.best_loss))
        logger.info(
            'Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
            format(output, trainer.stopper.best_epoch,
                   trainer.stopper.best_loss))
        shutil.copy('{}_{}.mlmodel'.format(output, trainer.stopper.best_epoch),
                    '{}_best.mlmodel'.format(output))
def line_generator(ctx, font, maxlines, encoding, normalization, renormalize,
                   reorder, font_size, font_weight, language, max_length,
                   strip, disable_degradation, alpha, beta, distort,
                   distortion_sigma, legacy, output, text):
    """
    Generates artificial text line training data.
    """
    # Reads all files in `text`, deduplicates (and optionally normalizes,
    # strips, length-filters and subsamples) their lines, then renders each
    # line to `<output>/<idx>.png` with the text in `<output>/<idx>.gt.txt`
    # (UTF-8). Returns None immediately when `text` is empty.
    import numpy as np

    from kraken import linegen
    from kraken.lib.util import make_printable

    lines: Set[str] = set()
    if not text:
        return
    with log.progressbar(text, label='Reading texts') as bar:
        # iterate the bar (not `text`) so the progress display advances
        for t in bar:
            with click.open_file(t, encoding=encoding) as fp:
                logger.info('Reading {}'.format(t))
                for l in fp:
                    lines.add(l.rstrip('\r\n'))
    if normalization:
        lines = {unicodedata.normalize(normalization, line) for line in lines}
    if strip:
        lines = {line.strip() for line in lines}
    if max_length:
        lines = {line for line in lines if len(line) < max_length}
    logger.info('Read {} lines'.format(len(lines)))
    message('Read {} unique lines'.format(len(lines)))
    if maxlines and maxlines < len(lines):
        message('Sampling {} lines\t'.format(maxlines), nl=False)
        llist = list(lines)
        # sample without replacement; drawing indices with replacement and
        # collapsing them into a set could yield fewer than `maxlines` lines
        picked = np.random.permutation(len(llist))[:maxlines]
        lines = set(llist[idx] for idx in picked)
        message('\u2713', fg='green')
    # an already existing output directory is fine
    os.makedirs(output, exist_ok=True)

    # calculate the alphabet and print it for verification purposes
    alphabet: Set[str] = set()
    for line in lines:
        alphabet.update(line)
    chars = []
    combining = []
    for char in sorted(alphabet):
        k = make_printable(char)
        # characters that make_printable had to rewrite are listed separately
        if k != char:
            combining.append(k)
        else:
            chars.append(k)
    message('Σ (len: {})'.format(len(alphabet)))
    message('Symbols: {}'.format(''.join(chars)))
    if combining:
        message('Combining Characters: {}'.format(', '.join(combining)))

    lg = linegen.LineGenerator(font, font_size, font_weight, language)
    with log.progressbar(lines, label='Writing images') as bar:
        for idx, line in enumerate(bar):
            logger.info(line)
            try:
                if renormalize:
                    im = lg.render_line(
                        unicodedata.normalize(renormalize, line))
                else:
                    im = lg.render_line(line)
            except KrakenCairoSurfaceException as e:
                # lines that cannot be rendered are logged and skipped
                logger.info('{}: {} {}'.format(e.message, e.width, e.height))
                continue
            if not disable_degradation and not legacy:
                im = linegen.degrade_line(im, alpha=alpha, beta=beta)
                im = linegen.distort_line(
                    im, abs(np.random.normal(distort)),
                    abs(np.random.normal(distortion_sigma)))
            elif legacy:
                im = linegen.ocropy_degrade(im)
            im.save('{}/{:06d}.png'.format(output, idx))
            with open('{}/{:06d}.gt.txt'.format(output, idx), 'wb') as fp:
                if reorder:
                    fp.write(get_display(line).encode('utf-8'))
                else:
                    fp.write(line.encode('utf-8'))
def recognizer(model, pad, no_segmentation, bidi_reordering, script_ignore,
               base_image, input, output, lines) -> None:
    """
    Recognizes the text lines of an image and writes the result to `output`.

    The segmentation is read as JSON from `lines`, or from `input` when
    `lines` is empty and `input` differs from `base_image`. With
    `no_segmentation` and no segmentation available, a temporary
    segmentation with a single box covering the whole image is used.
    The serialization format and text direction are taken from the current
    click context's `meta` dict ('mode', 'text_direction').

    Raises:
        click.BadParameter: if `base_image` cannot be opened.
        click.UsageError: if no segmentation is available or it is not
                          valid JSON.
    """
    import json
    import tempfile

    from kraken import rpred

    try:
        im = Image.open(base_image)
    except IOError as e:
        raise click.BadParameter(str(e))

    ctx = click.get_current_context()

    # input may either be output from the segmenter then it is a JSON file or
    # be an image file when running the OCR subcommand alone. might still come
    # from some other subcommand though.
    scripts = set()
    if not lines and base_image != input:
        lines = input

    # remembers whether `lines` points to a file created here, so it can be
    # removed again once it has been read
    temporary_segmentation = False
    if not lines:
        if no_segmentation:
            lines = tempfile.NamedTemporaryFile(mode='w', delete=False)
            logger.info(
                'Running in no_segmentation mode. Creating temporary segmentation {}.'
                .format(lines.name))
            json.dump(
                {
                    'script_detection': False,
                    'text_direction': 'horizontal-lr',
                    'boxes': [(0, 0) + im.size]
                }, lines)
            lines.close()
            lines = lines.name
            temporary_segmentation = True
        else:
            raise click.UsageError(
                'No line segmentation given. Add one with `-l` or run `segment` first.'
            )
    elif no_segmentation:
        logger.warning(
            'no_segmentation mode enabled but segmentation defined. Ignoring --no-segmentation option.'
        )

    try:
        with open_file(lines, 'r') as fp:
            try:
                fp = cast(IO[Any], fp)
                bounds = json.load(fp)
            except ValueError as e:
                raise click.UsageError('{} invalid segmentation: {}'.format(
                    lines, str(e)))
    finally:
        # the segmentation is fully loaded into `bounds` (or unusable) by now
        if temporary_segmentation:
            logger.debug('Removing temporary segmentation file.')
            os.unlink(lines)

    # script detection
    if bounds['script_detection']:
        for l in bounds['boxes']:
            for t in l:
                scripts.add(t[0])
        it = rpred.mm_rpred(model,
                            im,
                            bounds,
                            pad,
                            bidi_reordering=bidi_reordering,
                            script_ignore=script_ignore)
    else:
        it = rpred.rpred(model['default'],
                         im,
                         bounds,
                         pad,
                         bidi_reordering=bidi_reordering)

    preds = []

    with log.progressbar(it, label='Processing',
                         length=len(bounds['boxes'])) as bar:
        for pred in bar:
            preds.append(pred)

    with open_file(output, 'w', encoding='utf-8') as fp:
        fp = cast(IO[Any], fp)
        message('Writing recognition results for {}\t'.format(base_image),
                nl=False)
        logger.info('Serializing as {} into {}'.format(ctx.meta['mode'],
                                                       output))
        if ctx.meta['mode'] != 'text':
            from kraken import serialization
            # reuse the already opened image instead of opening it again
            fp.write(
                serialization.serialize(preds, base_image, im.size,
                                        ctx.meta['text_direction'], scripts,
                                        ctx.meta['mode']))
        else:
            fp.write('\n'.join(s.prediction for s in preds))
        message('\u2713', fg='green')
def train(ctx, pad, output, spec, append, load, savefreq, report, quit,
          epochs, lag, min_delta, device, optimizer, lrate, momentum,
          weight_decay, schedule, partition, normalization, codec, resize,
          reorder, training_files, evaluation_files, preload, threads,
          ground_truth):
    """
    Trains a model from image-text pairs.
    """
    # Outline:
    #   1. validate option combinations and optionally load an existing model
    #   2. build training/validation GroundTruthDatasets from the line images
    #   3. create, append to, or resize the network so its codec fits the data
    #   4. run the training loop, saving every `savefreq` epochs and
    #      evaluating every `report` epochs until the stopper ends iteration
    #
    # Raises click.BadOptionUsage / click.UsageError on invalid options or
    # missing training data; calls ctx.exit(1) when the model codec does not
    # cover the training alphabet and resize == 'fail'.
    if not load and append:
        raise click.BadOptionUsage(
            'append', 'append option requires loading an existing model')

    if resize != 'fail' and not load:
        raise click.BadOptionUsage(
            'resize', 'resize option requires loading an existing model')

    import re
    import torch
    import shutil
    import numpy as np

    from torch.utils.data import DataLoader

    from kraken.lib import models, vgsl, train
    from kraken.lib.util import make_printable
    from kraken.lib.train import EarlyStopping, EpochStopping, TrainStopper, TrainScheduler, add_1cycle
    from kraken.lib.codec import PytorchCodec
    from kraken.lib.dataset import GroundTruthDataset, compute_error, generate_input_transforms

    logger.info('Building ground truth set from {} line images'.format(
        len(ground_truth) + len(training_files)))

    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise the output size is still unknown.
    nn = None
    if load:
        logger.info('Loading existing model from {} '.format(load))
        message('Loading model {}'.format(load), nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        message('\u2713', fg='green', nl=False)

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        # startswith/endswith instead of indexing so that an empty spec is
        # reported as a usage error rather than an IndexError
        if not spec.startswith('[') or not spec.endswith(']'):
            raise click.BadOptionUsage(
                'spec', 'VGSL spec {} not bracketed'.format(spec))
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage(
                'spec', 'Invalid input spec {}'.format(blocks[0]))
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch, height, width, channels,
                                               pad)
    except KrakenInputException as e:
        raise click.BadOptionUsage('spec', str(e))

    # disable automatic partition when given evaluation set explicitly
    if evaluation_files:
        partition = 1

    ground_truth = list(ground_truth)

    # merge training_files into ground_truth list
    if training_files:
        ground_truth.extend(training_files)

    if len(ground_truth) == 0:
        raise click.UsageError(
            'No training data was provided to the train command. Use `-t` or the `ground_truth` argument.'
        )

    np.random.shuffle(ground_truth)

    if len(ground_truth) > 2500 and not preload:
        logger.info(
            'Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter'
        )
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    tr_im = ground_truth[:int(len(ground_truth) * partition)]
    if evaluation_files:
        logger.debug('Using {} lines from explicit eval set'.format(
            len(evaluation_files)))
        te_im = evaluation_files
    else:
        te_im = ground_truth[int(len(ground_truth) * partition):]
        logger.debug('Taking {} lines from training for evaluation'.format(
            len(te_im)))

    gt_set = GroundTruthDataset(normalization=normalization,
                                reorder=reorder,
                                im_transforms=transforms,
                                preload=preload)
    with log.progressbar(tr_im, label='Building training set') as bar:
        for im in bar:
            logger.debug('Adding line {} to training set'.format(im))
            try:
                gt_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    val_set = GroundTruthDataset(normalization=normalization,
                                 reorder=reorder,
                                 im_transforms=transforms,
                                 preload=preload)
    with log.progressbar(te_im, label='Building validation set') as bar:
        for im in bar:
            logger.debug('Adding line {} to validation set'.format(im))
            try:
                val_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    logger.info(
        'Training set {} lines, validation set {} lines, alphabet {} symbols'.
        format(len(gt_set._images), len(val_set._images),
               len(gt_set.alphabet)))
    alpha_diff = set(gt_set.alphabet).symmetric_difference(
        set(val_set.alphabet))
    if alpha_diff:
        # Logger.warn is a deprecated alias of Logger.warning
        logger.warning('alphabet mismatch {}'.format(alpha_diff))
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(),
                       key=lambda x: x[1],
                       reverse=True):
        char = make_printable(k)
        if char == k:
            char = '\t' + char
        logger.info(u'{}\t{}'.format(char, v))

    logger.debug('Encoding training set')

    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        logger.info('Appending {} to existing model {} after {}'.format(
            spec, nn.spec, append))
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info('Assembled model spec: {}'.format(nn.spec))
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException:
            message('Network codec not compatible with training set')
            alpha_diff = set(gt_set.alphabet).difference(
                set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error(
                    'Training data and model codec alphabets mismatch: {}'.
                    format(alpha_diff))
                ctx.exit(code=1)
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info(
                    'Resizing codec to include {} new code points'.format(
                        len(alpha_diff)))
                codec.c2l.update({
                    k: [v]
                    for v, k in enumerate(alpha_diff,
                                          start=codec.max_label() + 1)
                })
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info(
                    'Resizing last layer in network to {} outputs'.format(
                        codec.max_label() + 1))
                nn.resize_output(codec.max_label() + 1)
                # the first encode attempt raised, so the training set still
                # has to be encoded with the extended codec before training
                gt_set.encode(nn.codec)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ', nl=False)
                logger.info(
                    'Resizing network or given codec to {} code sequences'.
                    format(len(gt_set.alphabet)))
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info(
                    'Deleting {} output classes from network ({} retained)'.
                    format(len(del_labels),
                           len(codec) - len(del_labels)))
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label() + 1, del_labels)
                message('\u2713', fg='green')
            else:
                raise click.BadOptionUsage(
                    'resize', 'Invalid resize value {}'.format(resize))
    else:
        gt_set.encode(codec)
        logger.info('Creating new model {} with {} outputs'.format(
            spec,
            gt_set.codec.max_label() + 1))
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    train_loader = DataLoader(gt_set,
                              batch_size=1,
                              shuffle=True,
                              pin_memory=True)

    # don't encode validation set as the alphabets may not match causing
    # encoding failures
    val_set.training_set = list(zip(val_set._images, val_set._gt))

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(
        optimizer, lrate, momentum))

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    nn.set_num_threads(threads)

    logger.debug('Moving model to device {}'.format(device))
    rec = models.TorchSeqRecognizer(nn, train=True, device=device)
    # lr=0 is a placeholder; the scheduler phases below carry lrate/momentum
    optim = getattr(torch.optim, optimizer)(nn.nn.parameters(), lr=0)

    tr_it = TrainScheduler(optim)
    if schedule == '1cycle':
        add_1cycle(tr_it, epochs * len(gt_set), lrate, momentum,
                   momentum - 0.10, weight_decay)
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1, (lrate, lrate), (momentum, momentum), weight_decay,
                        train.annealing_const)

    st_it = cast(TrainStopper, None)  # type: TrainStopper
    if quit == 'early':
        st_it = EarlyStopping(train_loader, min_delta, lag)
    elif quit == 'dumb':
        st_it = EpochStopping(train_loader, epochs)
    else:
        raise click.BadOptionUsage(
            'quit', 'Invalid training interruption scheme {}'.format(quit))

    for epoch, loader in enumerate(st_it):
        with log.progressbar(label='epoch {}/{}'.format(
                epoch, epochs - 1 if epochs > 0 else '∞'),
                             length=len(loader),
                             show_pos=True) as bar:
            for trial, (input, target) in enumerate(loader):
                tr_it.step()
                input = input.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)
                input = input.requires_grad_()
                o = nn.nn(input)
                # height should be 1 by now
                if o.size(2) != 1:
                    raise KrakenInputException(
                        'Expected dimension 3 to be 1, actual {}'.format(
                            o.size(2)))
                o = o.squeeze(2)
                optim.zero_grad()
                # NCW -> WNC
                loss = nn.criterion(
                    o.permute(2, 0, 1),  # type: ignore
                    target,
                    (o.size(2), ),
                    (target.size(1), ))
                logger.info('trial {}'.format(trial))
                # skip the update step for samples with infinite loss
                if not torch.isinf(loss):
                    loss.backward()
                    optim.step()
                else:
                    logger.debug('infinite loss in trial {}'.format(trial))
                bar.update(1)
        if not epoch % savefreq:
            logger.info('Saving to {}_{}'.format(output, epoch))
            try:
                nn.save_model('{}_{}.mlmodel'.format(output, epoch))
            except Exception as e:
                # a failed checkpoint is logged but does not abort training
                logger.error('Saving model failed: {}'.format(str(e)))
        if not epoch % report:
            logger.debug('Starting evaluation run')
            nn.eval()
            chars, error = compute_error(rec, list(val_set))
            nn.train()
            accuracy = (chars - error) / chars
            logger.info('Accuracy report ({}) {:0.4f} {} {}'.format(
                epoch, accuracy, chars, error))
            message('Accuracy report ({}) {:0.4f} {} {}'.format(
                epoch, accuracy, chars, error))
            st_it.update(accuracy)

    if quit == 'early':
        # the best checkpoint is copied (not removed) to <output>_best.mlmodel
        message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
                format(output, st_it.best_epoch, st_it.best_loss))
        logger.info(
            'Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
            format(output, st_it.best_epoch, st_it.best_loss))
        shutil.copy('{}_{}.mlmodel'.format(output, st_it.best_epoch),
                    '{}_best.mlmodel'.format(output))
def recognizer(model, pad, no_segmentation, bidi_reordering, script_ignore,
               base_image, input, output, lines) -> None:
    """
    Recognizes the text lines of an image and writes the result to `output`.

    The segmentation is read as JSON from `lines`, or from `input` when
    `lines` is empty and `input` differs from `base_image`. With
    `no_segmentation` and no segmentation available, a temporary
    segmentation with a single box covering the whole image is used.
    The serialization format and text direction are taken from the current
    click context's `meta` dict ('mode', 'text_direction').

    Raises:
        click.BadParameter: if `base_image` cannot be opened.
        click.UsageError: if no segmentation is available or it is not
                          valid JSON.
    """
    import json
    import tempfile

    from kraken import rpred

    try:
        im = Image.open(base_image)
    except IOError as e:
        raise click.BadParameter(str(e))

    ctx = click.get_current_context()

    # input may either be output from the segmenter then it is a JSON file or
    # be an image file when running the OCR subcommand alone. might still come
    # from some other subcommand though.
    scripts = set()
    if not lines and base_image != input:
        lines = input

    # set when the segmentation file is created below and therefore has to
    # be deleted again after it has been read
    created_tmp_file = False
    if not lines:
        if not no_segmentation:
            raise click.UsageError('No line segmentation given. Add one with `-l` or run `segment` first.')
        lines = tempfile.NamedTemporaryFile(mode='w', delete=False)
        logger.info('Running in no_segmentation mode. Creating temporary segmentation {}.'.format(lines.name))
        json.dump({'script_detection': False,
                   'text_direction': 'horizontal-lr',
                   'boxes': [(0, 0) + im.size]}, lines)
        lines.close()
        lines = lines.name
        created_tmp_file = True
    elif no_segmentation:
        logger.warning('no_segmentation mode enabled but segmentation defined. Ignoring --no-segmentation option.')

    try:
        with open_file(lines, 'r') as fp:
            try:
                fp = cast(IO[Any], fp)
                bounds = json.load(fp)
            except ValueError as e:
                raise click.UsageError('{} invalid segmentation: {}'.format(lines, str(e)))
    finally:
        # `lines` is a path string at this point; the original check
        # (`not lines`) could never be true, leaving the file behind
        if created_tmp_file:
            logger.debug('Removing temporary segmentation file.')
            os.unlink(lines)

    # script detection
    if bounds['script_detection']:
        for l in bounds['boxes']:
            for t in l:
                scripts.add(t[0])
        it = rpred.mm_rpred(model, im, bounds, pad,
                            bidi_reordering=bidi_reordering,
                            script_ignore=script_ignore)
    else:
        it = rpred.rpred(model['default'], im, bounds, pad,
                         bidi_reordering=bidi_reordering)

    preds = []

    with log.progressbar(it, label='Processing', length=len(bounds['boxes'])) as bar:
        for pred in bar:
            preds.append(pred)

    with open_file(output, 'w', encoding='utf-8') as fp:
        fp = cast(IO[Any], fp)
        message('Writing recognition results for {}\t'.format(base_image), nl=False)
        logger.info('Serializing as {} into {}'.format(ctx.meta['mode'], output))
        if ctx.meta['mode'] != 'text':
            from kraken import serialization
            # reuse the already opened image instead of opening it again
            fp.write(serialization.serialize(preds, base_image,
                                             im.size,
                                             ctx.meta['text_direction'],
                                             scripts,
                                             ctx.meta['mode']))
        else:
            fp.write('\n'.join(s.prediction for s in preds))
        message('\u2713', fg='green')