def compute_segmentation_map(im, mask: Optional[np.array] = None, model=None, device: str = 'cpu'):
    """
    Runs a segmentation model on an image and returns the upsampled heatmap.

    Args:
        im: Input image. It is passed to the model's input transforms and
            must expose `.mode` and `.size` (presumably a PIL.Image.Image).
        mask: Optional mask. When given it must be bitonal and have the same
              size as `im`. NOTE(review): the annotation says `np.array`, but
              the code uses `.mode`, `.convert()` and `.size`, i.e. a PIL
              image is expected. The converted mask is not used further
              within this function.
        model: Segmentation model exposing `input`, `one_channel_mode`, `nn`
               and `user_metadata`. Required despite the `None` default, as
               it is dereferenced unconditionally.
        device (str): Device the model and the input tensor are moved to.

    Returns:
        A dict with the keys 'heatmap' (network output as a numpy array,
        upsampled to the size of the rescaled input), 'cls_map'
        (`model.user_metadata['class_mapping']`), 'bounding_regions' (from
        the model metadata or None), 'scale' (original image size divided by
        the heatmap size) and 'scal_im' (rescaled input converted to mode
        'L').

    Raises:
        KrakenInputException: if the mask is not bitonal or its size differs
                              from the image size.
    """
    im_str = get_im_str(im)
    logger.info(f'Segmenting {im_str}')

    if model.input[1] == 1 and model.one_channel_mode == '1' and not is_bitonal(im):
        logger.warning('Running binary model on non-binary input image '
                       '(mode {}). This will result in severely degraded '
                       'performance'.format(im.mode))

    model.eval()
    model.to(device)

    if mask is not None:
        if mask.mode != '1' and not is_bitonal(mask):
            logger.error('Mask is not bitonal')
            raise KrakenInputException('Mask is not bitonal')
        mask = mask.convert('1')
        if mask.size != im.size:
            # these messages were plain strings before, so the placeholders
            # were emitted verbatim instead of the actual sizes
            size_msg = f'Mask size {mask.size} doesn\'t match image size {im.size}'
            logger.error(size_msg)
            raise KrakenInputException(size_msg)
        logger.info('Masking enabled in segmenter.')
        mask = pil2array(mask)

    batch, channels, height, width = model.input
    transforms = dataset.generate_input_transforms(batch, height, width, channels, 0, valid_norm=False)
    # only the first three transforms are applied to obtain the rescaled
    # image the heatmap is upsampled to
    res_tf = tf.Compose(transforms.transforms[:3])
    scal_im = res_tf(im).convert('L')

    with torch.no_grad():
        logger.debug('Running network forward pass')
        o, _ = model.nn(transforms(im).unsqueeze(0).to(device))
        logger.debug('Upsampling network output')
        # PIL sizes are (width, height), interpolate expects (height, width)
        o = F.interpolate(o, size=scal_im.size[::-1])
    o = o.squeeze().cpu().numpy()

    # o.shape[:0:-1] is (width, height) of the heatmap
    scale = np.divide(im.size, o.shape[:0:-1])
    bounding_regions = model.user_metadata['bounding_regions'] if 'bounding_regions' in model.user_metadata else None

    return {
        'heatmap': o,
        'cls_map': model.user_metadata['class_mapping'],
        'bounding_regions': bounding_regions,
        'scale': scale,
        'scal_im': scal_im
    }
def cli(model, files):
    """
    Renders heatmap visualizations of a segmentation model's output.

    For every image in `files` the model is run once and two files are
    written next to the input: `<stem>.heat.png`, the heatmap of output
    channel 2, and `<stem>.overlay.png`, the rescaled input with output
    channels 2, 1 and 0 composited on top in fixed colours.

    Args:
        model: Path of the model file loaded with
               `vgsl.TorchVGSLModel.load_model`.
        files: Iterable of image paths.

    Note: output channels 0-2 are indexed unconditionally, so the model has
    to produce at least three output channels.
    """
    import torch
    from PIL import Image
    from kraken.lib import vgsl, dataset
    import torch.nn.functional as F
    from os.path import splitext
    import torchvision.transforms as tf

    model = vgsl.TorchVGSLModel.load_model(model)
    model.eval()
    batch, channels, height, width = model.input
    transforms = dataset.generate_input_transforms(batch, height, width, channels, 0, valid_norm=False)
    # the rescaling part of the pipeline does not depend on the image, so
    # build it once instead of once per file
    res_tf = tf.Compose(transforms.transforms[:3])

    # (output channel, RGBA overlay colour) in compositing order
    layers = (
        (2, (0, 130, 200, 255)),
        (1, (230, 25, 75, 255)),
        (0, (60, 180, 75, 255)),
    )

    torch.set_num_threads(1)

    for img in files:
        print(img)
        stem = splitext(img)[0]
        # context manager closes the underlying file handle even on errors
        with Image.open(img) as im:
            scal_im = res_tf(im)
            with torch.no_grad():
                o, _ = model.nn(transforms(im).unsqueeze(0))
                # PIL sizes are (width, height), interpolate wants (height, width)
                o = F.interpolate(o, size=scal_im.size[::-1])
            o = o.squeeze().numpy()

            bl = scal_im.convert('RGBA')
            for idx, (channel, colour) in enumerate(layers):
                heat = Image.fromarray((o[channel] * 255).astype('uint8'))
                if idx == 0:
                    # only the first layer's heatmap is stored on its own
                    heat.save(stem + '.heat.png')
                overlay = Image.new('RGBA', scal_im.size, colour)
                bl = Image.composite(overlay, bl, heat)
            bl.save(stem + '.overlay.png')
def test(ctx, model, evaluation_files, device, pad, threads, test_set):
    """
    Evaluate on a test set.

    Loads every model path in `model`, runs each of them over all line
    images from `test_set` and `evaluation_files`, aligns the prediction
    with the text read from the matching `.gt.txt` file and prints a report
    per model followed by the mean and standard deviation of the accuracy
    over all models.

    Raises:
        click.UsageError: if no model or no evaluation data is given.
    """
    if not model:
        raise click.UsageError('No model to evaluate given.')

    import numpy as np
    from PIL import Image
    from kraken.serialization import render_report
    from kraken.lib import models
    from kraken.lib.dataset import global_align, compute_confusions, generate_input_transforms

    line_count = len(test_set) + len(evaluation_files)
    logger.info('Building test set from {} line images'.format(line_count))

    nets = {}
    for model_path in model:
        message('Loading model {}\t'.format(model_path), nl=False)
        nets[model_path] = models.load_any(model_path)
        message('\u2713', fg='green')

    line_images = list(test_set)

    # set number of OpenMP threads (through the first loaded model)
    logger.debug('Set OpenMP threads to {}'.format(threads))
    first_net = next(iter(nets.values()))
    first_net.nn.set_num_threads(threads)

    # merge evaluation_files into the test set
    if evaluation_files:
        line_images.extend(evaluation_files)

    if not line_images:
        raise click.UsageError('No evaluation data was provided to the test command. Use `-e` or the `test_set` argument.')

    accuracies = []
    for model_path, recognizer in nets.items():
        aligned_gt = []
        aligned_pred = []
        n_chars = 0
        n_errors = 0
        message('Evaluating {}'.format(model_path))
        logger.info('Evaluating {}'.format(model_path))
        batch, channels, height, width = recognizer.nn.input
        line_tf = generate_input_transforms(batch, height, width, channels, pad)
        with log.progressbar(line_images, label='Evaluating') as bar:
            for image_path in bar:
                tensor = line_tf(Image.open(image_path))
                # ground truth lives next to the image as <stem>.gt.txt
                gt_path = os.path.splitext(image_path)[0] + '.gt.txt'
                with open(gt_path, 'r') as fp:
                    reference = get_display(fp.read())
                hypothesis = recognizer.predict_string(tensor)
                n_chars += len(reference)
                distance, gt_part, pred_part = global_align(reference, hypothesis)
                aligned_gt.extend(gt_part)
                aligned_pred.extend(pred_part)
                n_errors += distance
        accuracies.append((n_chars - n_errors) / n_chars)
        confusions, scripts, ins, dels, subs = compute_confusions(aligned_gt, aligned_pred)
        report = render_report(model_path, n_chars, n_errors, confusions, scripts, ins, dels, subs)
        logger.info(report)
        message(report)

    summary = 'Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(accuracies) * 100, np.std(accuracies) * 100)
    logger.info(summary)
    message(summary)
def train(ctx, pad, output, spec, append, load, freq, quit, epochs, lag,
          min_delta, device, optimizer, lrate, momentum, weight_decay,
          schedule, partition, normalization, normalize_whitespace, codec,
          resize, reorder, training_files, evaluation_files, preload, threads,
          ground_truth):
    """
    Trains a model from image-text pairs.

    Builds training and validation data sets from the given line images,
    creates a new model from `spec`, loads an existing one, or appends new
    layers to a loaded one, adapts the codec to the training alphabet and
    finally runs a `KrakenTrainer`.

    Args:
        ctx: click context, used to exit with code 1 on a codec mismatch when
             `resize` is 'fail'.
        pad: Padding handed to `generate_input_transforms`.
        output: Filename prefix of the trainer; with early stopping the best
                model is copied to `<output>_best.mlmodel`.
        spec: Bracketed VGSL spec. Its first block `batch,height,width,channels`
              seeds the input transforms when no model is loaded.
        append: Position handed to `nn.append`; requires `load`.
        load: Path of an existing model to continue from.
        freq: Event frequency handed to the trainer.
        quit: Stopping scheme, 'early' or 'dumb'.
        epochs: Number of epochs for the 'dumb' stopper and the 1cycle length.
        lag, min_delta: Parameters of the early stopper.
        device: Device handed to the trainer.
        optimizer: Name of a class in `torch.optim`.
        lrate, momentum, weight_decay: Values fed to the training scheduler.
        schedule: '1cycle' or anything else for a constant learning rate.
        partition: Fraction of the ground truth used for training; forced to
                   1 when `evaluation_files` is given.
        normalization, normalize_whitespace, reorder: Handed to
                   `GroundTruthDataset`.
        codec: Codec handed to `gt_set.encode` (see the branches below).
        resize: 'fail', 'add' or 'both'; how to treat a mismatch between the
                loaded model's codec and the training alphabet.
        training_files, ground_truth: Training line images (merged).
        evaluation_files: Explicit validation line images.
        preload: Tri-state (True/False/None) preloading switch.
        threads: Thread budget, split between data loading and OpenMP.

    Raises:
        click.BadOptionUsage: on inconsistent or invalid option values.
        click.UsageError: if no training data is given.
    """
    if not load and append:
        raise click.BadOptionUsage('append', 'append option requires loading an existing model')

    if resize != 'fail' and not load:
        raise click.BadOptionUsage('resize', 'resize option requires loading an existing model')

    import re
    import torch
    import shutil
    import numpy as np

    from torch.utils.data import DataLoader

    from kraken.lib import models, vgsl, train
    from kraken.lib.util import make_printable
    from kraken.lib.train import EarlyStopping, EpochStopping, TrainStopper, TrainScheduler, add_1cycle
    from kraken.lib.codec import PytorchCodec
    from kraken.lib.dataset import GroundTruthDataset, generate_input_transforms

    logger.info('Building ground truth set from {} line images'.format(len(ground_truth) + len(training_files)))

    # always 0 here; only used to offset the epoch count of the 'dumb' stopper
    completed_epochs = 0
    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise the output size is still unknown.
    nn = None
    #hyper_fields = ['freq', 'quit', 'epochs', 'lag', 'min_delta', 'optimizer', 'lrate', 'momentum', 'weight_decay', 'schedule', 'partition', 'normalization', 'normalize_whitespace', 'reorder', 'preload', 'completed_epochs', 'output']

    if load:
        logger.info('Loading existing model from {} '.format(load))
        message('Loading existing model from {}'.format(load), nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        #if nn.user_metadata and load_hyper_parameters:
        #    for param in hyper_fields:
        #        if param in nn.user_metadata:
        #            logger.info('Setting \'{}\' to \'{}\''.format(param, nn.user_metadata[param]))
        #            message('Setting \'{}\' to \'{}\''.format(param, nn.user_metadata[param]))
        #            locals()[param] = nn.user_metadata[param]
        message('\u2713', fg='green', nl=False)

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        if spec[0] != '[' or spec[-1] != ']':
            raise click.BadOptionUsage('spec', 'VGSL spec {} not bracketed'.format(spec))
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage('spec', 'Invalid input spec {}'.format(blocks[0]))
        # note the ordering: the spec is (batch, height, width, channels)
        # while a loaded model reports (batch, channels, height, width)
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch, height, width, channels, pad)
    except KrakenInputException as e:
        raise click.BadOptionUsage('spec', str(e))

    # disable automatic partition when given evaluation set explicitly
    if evaluation_files:
        partition = 1

    ground_truth = list(ground_truth)

    # merge training_files into ground_truth list
    if training_files:
        ground_truth.extend(training_files)

    if len(ground_truth) == 0:
        raise click.UsageError('No training data was provided to the train command. Use `-t` or the `ground_truth` argument.')

    # shuffle so that the train/validation split below is random
    np.random.shuffle(ground_truth)

    # preload is tri-state: None (not given) and False both disable
    # preloading for large sets, only an explicit True keeps it enabled
    if len(ground_truth) > 2500 and not preload:
        logger.info('Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter')
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    tr_im = ground_truth[:int(len(ground_truth) * partition)]
    if evaluation_files:
        logger.debug('Using {} lines from explicit eval set'.format(len(evaluation_files)))
        te_im = evaluation_files
    else:
        te_im = ground_truth[int(len(ground_truth) * partition):]
        logger.debug('Taking {} lines from training for evaluation'.format(len(te_im)))

    # set multiprocessing tensor sharing strategy
    if 'file_system' in torch.multiprocessing.get_all_sharing_strategies():
        logger.debug('Setting multiprocessing tensor sharing strategy to file_system')
        torch.multiprocessing.set_sharing_strategy('file_system')

    gt_set = GroundTruthDataset(normalization=normalization,
                                whitespace_normalization=normalize_whitespace,
                                reorder=reorder,
                                im_transforms=transforms,
                                preload=preload)
    with log.progressbar(tr_im, label='Building training set') as bar:
        for im in bar:
            logger.debug('Adding line {} to training set'.format(im))
            # unreadable or invalid lines are skipped with a warning
            try:
                gt_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    val_set = GroundTruthDataset(normalization=normalization,
                                 whitespace_normalization=normalize_whitespace,
                                 reorder=reorder,
                                 im_transforms=transforms,
                                 preload=preload)
    with log.progressbar(te_im, label='Building validation set') as bar:
        for im in bar:
            logger.debug('Adding line {} to validation set'.format(im))
            try:
                val_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    logger.info('Training set {} lines, validation set {} lines, alphabet {} symbols'.format(len(gt_set._images), len(val_set._images), len(gt_set.alphabet)))
    # report characters that occur in only one of the two sets
    alpha_diff_only_train = set(gt_set.alphabet).difference(set(val_set.alphabet))
    alpha_diff_only_val = set(val_set.alphabet).difference(set(gt_set.alphabet))
    if alpha_diff_only_train:
        logger.warning('alphabet mismatch: chars in training set only: {} (not included in accuracy test during training)'.format(alpha_diff_only_train))
    if alpha_diff_only_val:
        logger.warning('alphabet mismatch: chars in validation set only: {} (not trained)'.format(alpha_diff_only_val))
    # log the training alphabet, most frequent grapheme first
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(), key=lambda x: x[1], reverse=True):
        char = make_printable(k)
        if char == k:
            char = '\t' + char
        logger.info(u'{}\t{}'.format(char, v))

    logger.debug('Encoding training set')

    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label()+1)
        logger.info('Appending {} to existing model {} after {}'.format(spec, nn.spec, append))
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info('Assembled model spec: {}'.format(nn.spec))
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException as e:
            message('Network codec not compatible with training set')
            alpha_diff = set(gt_set.alphabet).difference(set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error('Training data and model codec alphabets mismatch: {}'.format(alpha_diff))
                ctx.exit(code=1)
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info('Resizing codec to include {} new code points'.format(len(alpha_diff)))
                # new code points get consecutive labels after the current maximum
                codec.c2l.update({k: [v] for v, k in enumerate(alpha_diff, start=codec.max_label()+1)})
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info('Resizing last layer in network to {} outputs'.format(codec.max_label()+1))
                nn.resize_output(codec.max_label()+1)
                gt_set.encode(nn.codec)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ', nl=False)
                logger.info('Resizing network or given codec to {} code sequences'.format(len(gt_set.alphabet)))
                # encode(None) first so that gt_set.codec is available for the merge
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info('Deleting {} output classes from network ({} retained)'.format(len(del_labels), len(codec)-len(del_labels)))
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label()+1, del_labels)
                message('\u2713', fg='green')
            else:
                raise click.BadOptionUsage('resize', 'Invalid resize value {}'.format(resize))
    else:
        gt_set.encode(codec)
        logger.info('Creating new model {} with {} outputs'.format(spec, gt_set.codec.max_label()+1))
        # append the output layer sized to the codec to the user-given spec
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label()+1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    # half the number of data loading processes if device isn't cuda and we haven't enabled preloading
    if device == 'cpu' and not preload:
        loader_threads = threads // 2
    else:
        loader_threads = threads
    train_loader = DataLoader(gt_set, batch_size=1, shuffle=True, num_workers=loader_threads, pin_memory=True)
    # whatever is left of the thread budget goes to OpenMP below
    threads -= loader_threads

    # don't encode validation set as the alphabets may not match causing encoding failures
    val_set.training_set = list(zip(val_set._images, val_set._gt))

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(optimizer, lrate, momentum))

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    nn.set_num_threads(threads)

    # NOTE(review): nothing is moved in this function; the device is only
    # handed to KrakenTrainer below, which presumably does the transfer.
    logger.debug('Moving model to device {}'.format(device))

    # lr=0 is a placeholder; the actual rates come from the scheduler phases
    optim = getattr(torch.optim, optimizer)(nn.nn.parameters(), lr=0)

    if 'accuracy' not in nn.user_metadata:
        nn.user_metadata['accuracy'] = []

    tr_it = TrainScheduler(optim)
    if schedule == '1cycle':
        add_1cycle(tr_it, int(len(gt_set) * epochs), lrate, momentum, momentum - 0.10, weight_decay)
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1, (lrate, lrate), (momentum, momentum), weight_decay, train.annealing_const)

    if quit == 'early':
        st_it = EarlyStopping(min_delta, lag)
    elif quit == 'dumb':
        st_it = EpochStopping(epochs - completed_epochs)
    else:
        raise click.BadOptionUsage('quit', 'Invalid training interruption scheme {}'.format(quit))

    #for param in hyper_fields:
    #    logger.debug('Setting \'{}\' to \'{}\' in model metadata'.format(param, locals()[param]))
    #    nn.user_metadata[param] = locals()[param]

    trainer = train.KrakenTrainer(model=nn,
                                  optimizer=optim,
                                  device=device,
                                  filename_prefix=output,
                                  event_frequency=freq,
                                  train_set=train_loader,
                                  val_set=val_set,
                                  stopper=st_it)

    trainer.add_lr_scheduler(tr_it)

    with log.progressbar(label='stage {}/{}'.format(1, trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞'),
                         length=trainer.event_it, show_pos=True) as bar:

        # callback advancing the bar, handed to trainer.run
        def _draw_progressbar():
            bar.update(1)

        # callback printing the evaluation result, handed to trainer.run
        def _print_eval(epoch, accuracy, chars, error):
            message('Accuracy report ({}) {:0.4f} {} {}'.format(epoch, accuracy, chars, error))
            # reset progress bar
            bar.label = 'stage {}/{}'.format(epoch+1, trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞')
            bar.pos = 0
            bar.finished = False

        trainer.run(_print_eval, _draw_progressbar)

    if quit == 'early':
        # the messages say "Moving" but the best checkpoint is copied, the
        # per-epoch file stays in place
        message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format(output, trainer.stopper.best_epoch, trainer.stopper.best_loss))
        logger.info('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format(output, trainer.stopper.best_epoch, trainer.stopper.best_loss))
        shutil.copy('{}_{}.mlmodel'.format(output, trainer.stopper.best_epoch), '{}_best.mlmodel'.format(output))
def rpred(network: TorchSeqRecognizer,
          im: Image.Image,
          bounds: dict,
          pad: int = 16,
          bidi_reordering: bool = True) -> Generator[ocr_record, None, None]:
    """
    Uses a RNN to recognize text

    Args:
        network (kraken.lib.models.TorchSeqRecognizer): A TorchSegRecognizer
                                                        object
        im (PIL.Image.Image): Image to extract text from
        bounds (dict): A dictionary containing a 'boxes' entry with a list of
                       coordinates (x0, y0, x1, y1) of a text line in the image
                       and an entry 'text_direction' containing
                       'horizontal-lr/rl/vertical-lr/rl'.
        pad (int): Extra blank padding to the left and right of text line.
                   Auto-disabled when expected network inputs are incompatible
                   with padding.
        bidi_reordering (bool): Reorder classes in the ocr_record according to
                                the Unicode bidirectional algorithm for correct
                                display.
    Yields:
        An ocr_record containing the recognized text, absolute character
        positions, and confidence values for each character. Exactly one
        record is yielded per box; boxes that are degenerate, fail the input
        transformation, or are uniformly coloured produce an empty record.
    """
    im_str = get_im_str(im)
    logger.info('Running recognizer on {} with {} lines'.format(im_str, len(bounds['boxes'])))
    logger.debug('Loading line transform')
    batch, channels, height, width = network.nn.input
    ts = generate_input_transforms(batch, height, width, channels, pad)

    for box, coords in extract_boxes(im, bounds):
        # check if boxes are non-zero in any dimension
        if sum(coords[::2]) == 0 or coords[3] - coords[1] == 0:
            logger.warning('bbox {} with zero dimension. Emitting empty record.'.format(coords))
            yield ocr_record('', [], [])
            continue
        # try conversion into tensor; a failing transform is deliberately
        # turned into an empty record so that output stays aligned with input
        try:
            line = ts(box)
        except Exception:
            yield ocr_record('', [], [])
            continue
        # check if line is non-zero
        if line.max() == line.min():
            yield ocr_record('', [], [])
            continue

        preds = network.predict(line)
        # calculate recognized LSTM locations of characters
        # scale between network output and network input
        net_scale = line.shape[2] / network.outputs.shape[1]
        # scale between network input and original line
        in_scale = box.size[0] / (line.shape[2] - 2 * pad)

        def _scale_val(val, min_val, max_val):
            # map a network output position to a clamped offset in the line
            return int(round(min(max(((val * net_scale) - pad) * in_scale, min_val), max_val)))

        # XXX: fix bounding box calculation ocr_record for multi-codepoint labels.
        pred = ''.join(x[0] for x in preds)
        pos = []
        conf = []

        for _, start, end, c in preds:
            if bounds['text_direction'].startswith('horizontal'):
                xmin = coords[0] + _scale_val(start, 0, box.size[0])
                xmax = coords[0] + _scale_val(end, 0, box.size[0])
                pos.append((xmin, coords[1], xmax, coords[3]))
            else:
                ymin = coords[1] + _scale_val(start, 0, box.size[1])
                # lower edge is derived from the end position; using the
                # start position here collapsed every box to zero height
                ymax = coords[1] + _scale_val(end, 0, box.size[1])
                pos.append((coords[0], ymin, coords[2], ymax))
            conf.append(c)
        if bidi_reordering:
            logger.debug('BiDi reordering record.')
            yield bidi_record(ocr_record(pred, pos, conf))
        else:
            logger.debug('Emitting raw record')
            yield ocr_record(pred, pos, conf)
def mm_rpred(nets: Dict[str, TorchSeqRecognizer],
             im: Image.Image,
             bounds: dict,
             pad: int = 16,
             bidi_reordering: bool = True,
             script_ignore: Optional[List[str]] = None) -> Generator[ocr_record, None, None]:
    """
    Multi-model version of kraken.rpred.rpred. Takes a dictionary of ISO15924
    script identifiers->models and an script-annotated segmentation to
    dynamically select appropriate models for these lines.

    Args:
        nets (dict): A dict mapping ISO15924 identifiers to TorchSegRecognizer
                     objects. Recommended to be an defaultdict.
        im (PIL.Image.Image): Image to extract text from
        bounds (dict): A dictionary containing a 'boxes' entry with a list of
                       lists of coordinates (script, (x0, y0, x1, y1)) of a
                       text line in the image and an entry 'text_direction'
                       containing 'horizontal-lr/rl/vertical-lr/rl'.
        pad (int): Extra blank padding to the left and right of text line
        bidi_reordering (bool): Reorder classes in the ocr_record according to
                                the Unicode bidirectional algorithm for correct
                                display.
        script_ignore (list): List of scripts to ignore during recognition
    Yields:
        An ocr_record containing the recognized text, absolute character
        positions, and confidence values for each character. One record is
        yielded per line; in addition an empty record is yielded for every
        run whose conversion fails or that is uniformly coloured.

    Raises:
        KrakenInputException if the mapping between segmentation scripts and
        networks is incomplete.
    """
    im_str = get_im_str(im)
    logger.info('Running {} multi-script recognizers on {} with {} lines'.format(len(nets), im_str, len(bounds['boxes'])))

    # every entry of bounds['boxes'] is a line consisting of (script, coords)
    # runs, so the scripts have to be collected from the runs of each line.
    # dict.fromkeys removes duplicates while keeping first-seen order.
    miss = list(dict.fromkeys(script
                              for segments in bounds['boxes']
                              for script, _ in segments
                              if not nets.get(script)))
    if miss:
        raise KrakenInputException('Missing models for scripts {}'.format(miss))

    # build dictionary for line preprocessing
    ts = {}
    for script, network in nets.items():
        logger.debug('Loading line transforms for {}'.format(script))
        batch, channels, height, width = network.nn.input
        ts[script] = generate_input_transforms(batch, height, width, channels, pad)

    for line in bounds['boxes']:
        rec = ocr_record('', [], [])
        # pair the script of every run with the box extracted for it
        run_boxes = extract_boxes(im, {'text_direction': bounds['text_direction'],
                                       'boxes': map(lambda x: x[1], line)})
        for script, (box, coords) in zip(map(lambda x: x[0], line), run_boxes):
            # skip if script is set to ignore
            if script_ignore is not None and script in script_ignore:
                logger.info('Ignoring {} line segment.'.format(script))
                continue
            # check if boxes are non-zero in any dimension
            if sum(coords[::2]) == 0 or coords[3] - coords[1] == 0:
                logger.warning('Run with zero dimension. Skipping.')
                continue
            # try conversion into tensor (named `tensor` so that the outer
            # loop variable `line` is not shadowed)
            try:
                logger.debug('Preparing run.')
                tensor = ts[script](box)
            except Exception:
                logger.warning('Conversion of line {} failed. Skipping.'.format(coords))
                yield ocr_record('', [], [])
                continue

            # check if line is non-zero
            if tensor.max() == tensor.min():
                logger.warning('Empty run. Skipping.')
                yield ocr_record('', [], [])
                continue

            logger.debug('Forward pass with model {}'.format(script))
            preds = nets[script].predict(tensor)

            # calculate recognized LSTM locations of characters
            logger.debug('Convert to absolute coordinates')
            scale = box.size[0] / (len(nets[script].outputs) - 2 * pad)
            pred = ''.join(x[0] for x in preds)
            pos = []
            conf = []

            for _, start, end, c in preds:
                if bounds['text_direction'].startswith('horizontal'):
                    xmin = coords[0] + int(max((start - pad) * scale, 0))
                    xmax = coords[0] + max(int(min((end - pad) * scale, coords[2] - coords[0])), 1)
                    pos.append((xmin, coords[1], xmax, coords[3]))
                else:
                    ymin = coords[1] + int(max((start - pad) * scale, 0))
                    ymax = coords[1] + max(int(min((end - pad) * scale, coords[3] - coords[1])), 1)
                    pos.append((coords[0], ymin, coords[2], ymax))
                conf.append(c)
            # runs of one line are concatenated into a single record
            rec.prediction += pred
            rec.cuts.extend(pos)
            rec.confidences.extend(conf)
        if bidi_reordering:
            logger.debug('BiDi reordering record.')
            yield bidi_record(rec)
        else:
            logger.debug('Emitting raw record')
            yield rec
def test(ctx, model, evaluation_files, device, pad, threads, test_set):
    """
    Evaluate on a test set.

    Every model given in `model` is loaded and scored against the line images
    in `test_set` plus `evaluation_files`. The expected text of an image is
    read from the `.gt.txt` file sharing its stem. A report is emitted per
    model, followed by the average accuracy and its standard deviation.

    Raises:
        click.UsageError: when `model` is empty or no line images are given.
    """
    if not model:
        raise click.UsageError('No model to evaluate given.')

    import numpy as np
    from PIL import Image
    from kraken.serialization import render_report
    from kraken.lib import models
    from kraken.lib.dataset import global_align, compute_confusions, generate_input_transforms

    logger.info('Building test set from {} line images'.format(
        len(test_set) + len(evaluation_files)))

    nn = {}
    for path in model:
        message('Loading model {}\t'.format(path), nl=False)
        nn[path] = models.load_any(path)
        message('\u2713', fg='green')

    samples = list(test_set)

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    next(iter(nn.values())).nn.set_num_threads(threads)

    # merge evaluation_files into the list of samples
    if evaluation_files:
        samples.extend(evaluation_files)

    if len(samples) == 0:
        raise click.UsageError(
            'No evaluation data was provided to the test command. Use `-e` or the `test_set` argument.'
        )

    def _read_ground_truth(image_path):
        # ground truth text is stored as <image stem>.gt.txt
        stem, _ = os.path.splitext(image_path)
        with open(stem + '.gt.txt', 'r') as fp:
            return get_display(fp.read())

    def _score(name, net):
        # evaluates a single model, emits its report and returns its accuracy
        gt_alignment = []
        pred_alignment = []
        total = 0
        wrong = 0
        message('Evaluating {}'.format(name))
        logger.info('Evaluating {}'.format(name))
        batch, channels, height, width = net.nn.input
        ts = generate_input_transforms(batch, height, width, channels, pad)
        with log.progressbar(samples, label='Evaluating') as bar:
            for sample in bar:
                line = ts(Image.open(sample))
                expected = _read_ground_truth(sample)
                predicted = net.predict_string(line)
                total += len(expected)
                cost, gt_part, pred_part = global_align(expected, predicted)
                gt_alignment.extend(gt_part)
                pred_alignment.extend(pred_part)
                wrong += cost
        # computed before the report, as a model without any ground truth
        # characters fails here
        accuracy = (total - wrong) / total
        confusions, scripts, ins, dels, subs = compute_confusions(
            gt_alignment, pred_alignment)
        rep = render_report(name, total, wrong, confusions, scripts, ins,
                            dels, subs)
        logger.info(rep)
        message(rep)
        return accuracy

    acc_list = []
    for name, net in nn.items():
        acc_list.append(_score(name, net))

    for emit in (logger.info, message):
        emit('Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(
            np.mean(acc_list) * 100, np.std(acc_list) * 100))
def train(ctx, pad, output, spec, append, load, freq, quit, epochs, lag,
          min_delta, device, optimizer, lrate, momentum, weight_decay,
          schedule, partition, normalization, normalize_whitespace, codec,
          resize, reorder, training_files, evaluation_files, preload, threads,
          ground_truth):
    """
    Trains a model from image-text pairs.

    Builds training and validation data sets from the given line images,
    creates a new model from `spec`, loads an existing one, or appends new
    layers to a loaded one, adapts the codec to the training alphabet and
    finally runs a `KrakenTrainer`.

    Args:
        ctx: click context, used to exit with code 1 on a codec mismatch when
             `resize` is 'fail'.
        pad: Padding handed to `generate_input_transforms`.
        output: Filename prefix of the trainer; with early stopping the best
                model is copied to `<output>_best.mlmodel`.
        spec: Bracketed VGSL spec. Its first block `batch,height,width,channels`
              seeds the input transforms when no model is loaded.
        append: Position handed to `nn.append`; requires `load`.
        load: Path of an existing model to continue from.
        freq: Event frequency handed to the trainer.
        quit: Stopping scheme, 'early' or 'dumb'.
        epochs: Number of epochs for the 'dumb' stopper and the 1cycle length.
        lag, min_delta: Parameters of the early stopper.
        device: Device handed to the trainer.
        optimizer: Name of a class in `torch.optim`.
        lrate, momentum, weight_decay: Values fed to the training scheduler.
        schedule: '1cycle' or anything else for a constant learning rate.
        partition: Fraction of the ground truth used for training; forced to
                   1 when `evaluation_files` is given.
        normalization, normalize_whitespace, reorder: Handed to
                   `GroundTruthDataset`.
        codec: Codec handed to `gt_set.encode` (see the branches below).
        resize: 'fail', 'add' or 'both'; how to treat a mismatch between the
                loaded model's codec and the training alphabet.
        training_files, ground_truth: Training line images (merged).
        evaluation_files: Explicit validation line images.
        preload: Tri-state (True/False/None) preloading switch.
        threads: Thread budget, split between data loading and OpenMP.

    Raises:
        click.BadOptionUsage: on inconsistent or invalid option values.
        click.UsageError: if no training data is given.
    """
    if not load and append:
        raise click.BadOptionUsage(
            'append', 'append option requires loading an existing model')

    if resize != 'fail' and not load:
        raise click.BadOptionUsage(
            'resize', 'resize option requires loading an existing model')

    import re
    import torch
    import shutil
    import numpy as np

    from torch.utils.data import DataLoader

    from kraken.lib import models, vgsl, train
    from kraken.lib.util import make_printable
    from kraken.lib.train import EarlyStopping, EpochStopping, TrainStopper, TrainScheduler, add_1cycle
    from kraken.lib.codec import PytorchCodec
    from kraken.lib.dataset import GroundTruthDataset, generate_input_transforms

    logger.info('Building ground truth set from {} line images'.format(
        len(ground_truth) + len(training_files)))

    # always 0 here; only used to offset the epoch count of the 'dumb' stopper
    completed_epochs = 0
    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise the output size is still unknown.
    nn = None
    #hyper_fields = ['freq', 'quit', 'epochs', 'lag', 'min_delta', 'optimizer', 'lrate', 'momentum', 'weight_decay', 'schedule', 'partition', 'normalization', 'normalize_whitespace', 'reorder', 'preload', 'completed_epochs', 'output']

    if load:
        logger.info('Loading existing model from {} '.format(load))
        message('Loading existing model from {}'.format(load), nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        #if nn.user_metadata and load_hyper_parameters:
        #    for param in hyper_fields:
        #        if param in nn.user_metadata:
        #            logger.info('Setting \'{}\' to \'{}\''.format(param, nn.user_metadata[param]))
        #            message('Setting \'{}\' to \'{}\''.format(param, nn.user_metadata[param]))
        #            locals()[param] = nn.user_metadata[param]
        message('\u2713', fg='green', nl=False)

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        if spec[0] != '[' or spec[-1] != ']':
            raise click.BadOptionUsage(
                'spec', 'VGSL spec {} not bracketed'.format(spec))
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage(
                'spec', 'Invalid input spec {}'.format(blocks[0]))
        # note the ordering: the spec is (batch, height, width, channels)
        # while a loaded model reports (batch, channels, height, width)
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch, height, width, channels, pad)
    except KrakenInputException as e:
        # keep the original error attached for debugging
        raise click.BadOptionUsage('spec', str(e)) from e

    # disable automatic partition when given evaluation set explicitly
    if evaluation_files:
        partition = 1

    ground_truth = list(ground_truth)

    # merge training_files into ground_truth list
    if training_files:
        ground_truth.extend(training_files)

    if len(ground_truth) == 0:
        raise click.UsageError(
            'No training data was provided to the train command. Use `-t` or the `ground_truth` argument.'
        )

    # shuffle so that the train/validation split below is random
    np.random.shuffle(ground_truth)

    # preload is tri-state: None (not given) and False both disable
    # preloading for large sets, only an explicit True keeps it enabled
    if len(ground_truth) > 2500 and not preload:
        logger.info(
            'Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter'
        )
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    tr_im = ground_truth[:int(len(ground_truth) * partition)]
    if evaluation_files:
        logger.debug('Using {} lines from explicit eval set'.format(
            len(evaluation_files)))
        te_im = evaluation_files
    else:
        te_im = ground_truth[int(len(ground_truth) * partition):]
        logger.debug('Taking {} lines from training for evaluation'.format(
            len(te_im)))

    # set multiprocessing tensor sharing strategy
    if 'file_system' in torch.multiprocessing.get_all_sharing_strategies():
        logger.debug(
            'Setting multiprocessing tensor sharing strategy to file_system')
        torch.multiprocessing.set_sharing_strategy('file_system')

    gt_set = GroundTruthDataset(normalization=normalization,
                                whitespace_normalization=normalize_whitespace,
                                reorder=reorder,
                                im_transforms=transforms,
                                preload=preload)
    with log.progressbar(tr_im, label='Building training set') as bar:
        for im in bar:
            logger.debug('Adding line {} to training set'.format(im))
            # unreadable or invalid lines are skipped with a warning
            try:
                gt_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    val_set = GroundTruthDataset(normalization=normalization,
                                 whitespace_normalization=normalize_whitespace,
                                 reorder=reorder,
                                 im_transforms=transforms,
                                 preload=preload)
    with log.progressbar(te_im, label='Building validation set') as bar:
        for im in bar:
            logger.debug('Adding line {} to validation set'.format(im))
            try:
                val_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    logger.info(
        'Training set {} lines, validation set {} lines, alphabet {} symbols'.
        format(len(gt_set._images), len(val_set._images),
               len(gt_set.alphabet)))
    # report the two directions of an alphabet mismatch separately: chars
    # only in the training set are never scored, chars only in the validation
    # set are never trained. Uses logger.warning as logger.warn is deprecated.
    alpha_diff_only_train = set(gt_set.alphabet).difference(set(val_set.alphabet))
    alpha_diff_only_val = set(val_set.alphabet).difference(set(gt_set.alphabet))
    if alpha_diff_only_train:
        logger.warning('alphabet mismatch: chars in training set only: {} (not included in accuracy test during training)'.format(alpha_diff_only_train))
    if alpha_diff_only_val:
        logger.warning('alphabet mismatch: chars in validation set only: {} (not trained)'.format(alpha_diff_only_val))
    # log the training alphabet, most frequent grapheme first
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(),
                       key=lambda x: x[1],
                       reverse=True):
        char = make_printable(k)
        if char == k:
            char = '\t' + char
        logger.info(u'{}\t{}'.format(char, v))

    logger.debug('Encoding training set')

    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        logger.info('Appending {} to existing model {} after {}'.format(
            spec, nn.spec, append))
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info('Assembled model spec: {}'.format(nn.spec))
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException:
            message('Network codec not compatible with training set')
            alpha_diff = set(gt_set.alphabet).difference(set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error(
                    'Training data and model codec alphabets mismatch: {}'.
                    format(alpha_diff))
                ctx.exit(code=1)
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info(
                    'Resizing codec to include {} new code points'.format(
                        len(alpha_diff)))
                # new code points get consecutive labels after the current maximum
                codec.c2l.update({
                    k: [v]
                    for v, k in enumerate(alpha_diff,
                                          start=codec.max_label() + 1)
                })
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info(
                    'Resizing last layer in network to {} outputs'.format(
                        codec.max_label() + 1))
                nn.resize_output(codec.max_label() + 1)
                gt_set.encode(nn.codec)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ', nl=False)
                logger.info(
                    'Resizing network or given codec to {} code sequences'.
                    format(len(gt_set.alphabet)))
                # encode(None) first so that gt_set.codec is available for the merge
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info(
                    'Deleting {} output classes from network ({} retained)'.
                    format(len(del_labels),
                           len(codec) - len(del_labels)))
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label() + 1, del_labels)
                message('\u2713', fg='green')
            else:
                raise click.BadOptionUsage(
                    'resize', 'Invalid resize value {}'.format(resize))
    else:
        gt_set.encode(codec)
        logger.info('Creating new model {} with {} outputs'.format(
            spec, gt_set.codec.max_label() + 1))
        # append the output layer sized to the codec to the user-given spec
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    # half the number of data loading processes if device isn't cuda and we haven't enabled preloading
    if device == 'cpu' and not preload:
        loader_threads = threads // 2
    else:
        loader_threads = threads
    train_loader = DataLoader(gt_set,
                              batch_size=1,
                              shuffle=True,
                              num_workers=loader_threads,
                              pin_memory=True)
    # whatever is left of the thread budget goes to OpenMP below
    threads -= loader_threads

    # don't encode validation set as the alphabets may not match causing encoding failures
    val_set.training_set = list(zip(val_set._images, val_set._gt))

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(
        optimizer, lrate, momentum))

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    nn.set_num_threads(threads)

    # NOTE(review): nothing is moved in this function; the device is only
    # handed to KrakenTrainer below, which presumably does the transfer.
    logger.debug('Moving model to device {}'.format(device))

    # lr=0 is a placeholder; the actual rates come from the scheduler phases
    optim = getattr(torch.optim, optimizer)(nn.nn.parameters(), lr=0)

    if 'accuracy' not in nn.user_metadata:
        nn.user_metadata['accuracy'] = []

    tr_it = TrainScheduler(optim)
    if schedule == '1cycle':
        add_1cycle(tr_it, int(len(gt_set) * epochs), lrate, momentum,
                   momentum - 0.10, weight_decay)
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1, (lrate, lrate), (momentum, momentum), weight_decay,
                        train.annealing_const)

    if quit == 'early':
        st_it = EarlyStopping(min_delta, lag)
    elif quit == 'dumb':
        st_it = EpochStopping(epochs - completed_epochs)
    else:
        raise click.BadOptionUsage(
            'quit', 'Invalid training interruption scheme {}'.format(quit))

    #for param in hyper_fields:
    #    logger.debug('Setting \'{}\' to \'{}\' in model metadata'.format(param, locals()[param]))
    #    nn.user_metadata[param] = locals()[param]

    trainer = train.KrakenTrainer(model=nn,
                                  optimizer=optim,
                                  device=device,
                                  filename_prefix=output,
                                  event_frequency=freq,
                                  train_set=train_loader,
                                  val_set=val_set,
                                  stopper=st_it)

    trainer.add_lr_scheduler(tr_it)

    with log.progressbar(label='stage {}/{}'.format(
            1, trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞'),
                         length=trainer.event_it,
                         show_pos=True) as bar:

        # callback advancing the bar, handed to trainer.run
        def _draw_progressbar():
            bar.update(1)

        # callback printing the evaluation result, handed to trainer.run
        def _print_eval(epoch, accuracy, chars, error):
            message('Accuracy report ({}) {:0.4f} {} {}'.format(
                epoch, accuracy, chars, error))
            # reset progress bar
            bar.label = 'stage {}/{}'.format(
                epoch + 1,
                trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞')
            bar.pos = 0
            bar.finished = False

        trainer.run(_print_eval, _draw_progressbar)

    if quit == 'early':
        # the messages say "Moving" but the best checkpoint is copied, the
        # per-epoch file stays in place
        message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
                format(output, trainer.stopper.best_epoch,
                       trainer.stopper.best_loss))
        logger.info(
            'Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
            format(output, trainer.stopper.best_epoch,
                   trainer.stopper.best_loss))
        shutil.copy('{}_{}.mlmodel'.format(output, trainer.stopper.best_epoch),
                    '{}_best.mlmodel'.format(output))
def segmentation_train_gen(
        cls,
        hyper_params: Dict = default_specs.SEGMENTATION_HYPER_PARAMS,
        progress_callback: Callable[[str, int], Callable[[None], None]] = lambda string, length: lambda: None,
        message: Callable[[str], None] = lambda *args, **kwargs: None,
        output: str = 'model',
        spec: str = default_specs.SEGMENTATION_SPEC,
        load: Optional[str] = None,
        device: str = 'cpu',
        training_data: Sequence[Dict] = None,
        evaluation_data: Sequence[Dict] = None,
        threads: int = 1,
        load_hyper_parameters: bool = False,
        force_binarization: bool = False,
        format_type: str = 'path',
        suppress_regions: bool = False,
        suppress_baselines: bool = False,
        valid_regions: Optional[Sequence[str]] = None,
        valid_baselines: Optional[Sequence[str]] = None,
        merge_regions: Optional[Dict[str, str]] = None,
        merge_baselines: Optional[Dict[str, str]] = None,
        augment: bool = False):
    """
    Builds a trainer for a segmentation model from command line style
    arguments.

    This is an ugly constructor that takes all the arguments from the command
    line driver, finagles the datasets, models, and hyperparameters correctly
    and returns a trainer object created through `cls`.

    Setup parameters (load, training_data, evaluation_data, ....) are named,
    model hyperparameters (everything in
    kraken.lib.default_specs.SEGMENTATION_HYPER_PARAMS) are in the
    `hyper_params` argument.

    Args:
        cls: Callable used to instantiate the trainer.
        hyper_params (dict): Hyperparameter dictionary containing all fields
                             from
                             kraken.lib.default_specs.SEGMENTATION_HYPER_PARAMS.
                             The keys read here are `line_width`, `augment`,
                             `optimizer`, `schedule`, `epochs`, `lrate`,
                             `momentum`, `weight_decay`, `quit`, `min_delta`,
                             `lag`, `completed_epochs`, and `freq`. An
                             explicitly passed dictionary is updated in place
                             when `load_hyper_parameters` is set; the shared
                             default dictionary is never modified.
        progress_callback (Callable): Accepted for interface compatibility;
                                      not used by this function.
        message (Callable): Messaging printing method for above log but below
                            warning level output, i.e. infos that should
                            generally be shown to users.
        output (str): Passed to the trainer as `filename_prefix`.
        spec (str): Bracketed VGSL spec used when no model is loaded. An
                    output layer is appended to it once the number of classes
                    is known.
        load (str): Location of an existing model to continue training from.
        device (str): Passed to the trainer; `'cpu'` also halves the number of
                      data loader workers.
        training_data (Sequence): Training data handed to `BaselineSet`.
        evaluation_data (Sequence): Validation data handed to `BaselineSet`.
        threads (int): Thread budget split between data loading workers and
                       the model.
        load_hyper_parameters (bool): Merge the hyperparameters stored in the
                                      loaded model into `hyper_params`.
        force_binarization (bool): Accepted for interface compatibility; not
                                   used by this function.
        format_type (str): Passed to `BaselineSet` as `mode`. When `None` the
                           entries of `training_data`/`evaluation_data` are
                           added one by one as keyword arguments.
        suppress_regions (bool): Disables all region types.
        suppress_baselines (bool): Disables all baseline types.
        valid_regions (Sequence): Region types to keep; empty means no filter.
        valid_baselines (Sequence): Baseline types to keep; empty means no
                                    filter.
        merge_regions (dict): Region type merge mapping.
        merge_baselines (dict): Baseline type merge mapping.
        augment (bool): Accepted for interface compatibility; augmentation is
                        read from `hyper_params['augment']`.

    Returns:
        The trainer object, or None if the configuration is invalid (the
        reason is logged at error level).
    """
    # Never mutate the shared module-level defaults: a later
    # `hyper_params.update()` would otherwise leak the loaded model's
    # hyperparameters into every subsequent call relying on the default.
    if hyper_params is default_specs.SEGMENTATION_HYPER_PARAMS:
        hyper_params = dict(hyper_params)

    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise the output size is still
    # unknown.
    nn = None

    if load:
        logger.info(f'Loading existing model from {load} ')
        message(f'Loading existing model from {load} ', nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        if load_hyper_parameters:
            hyper_params.update(nn.hyper_params)
            nn.hyper_params = hyper_params
        message('\u2713', fg='green', nl=False)

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        # startswith/endswith also reject an empty spec instead of raising
        # an IndexError.
        if not (spec.startswith('[') and spec.endswith(']')):
            logger.error(f'VGSL spec "{spec}" not bracketed')
            return None
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            logger.error(f'Invalid input spec {blocks[0]}')
            return None
        # the spec orders dimensions differently from `nn.input`
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch, height, width, channels, 0, valid_norm=False)
    except KrakenInputException as e:
        logger.error(f'Spec error: {e}')
        return None

    # set multiprocessing tensor sharing strategy
    if 'file_system' in torch.multiprocessing.get_all_sharing_strategies():
        logger.debug('Setting multiprocessing tensor sharing strategy to file_system')
        torch.multiprocessing.set_sharing_strategy('file_system')

    # empty filters mean "no filtering" ...
    if not valid_regions:
        valid_regions = None
    if not valid_baselines:
        valid_baselines = None

    # ... while suppression is expressed as an explicit empty whitelist
    if suppress_regions:
        valid_regions = []
        merge_regions = None
    if suppress_baselines:
        valid_baselines = []
        merge_baselines = None

    gt_set = BaselineSet(training_data,
                         line_width=hyper_params['line_width'],
                         im_transforms=transforms,
                         mode=format_type,
                         augmentation=hyper_params['augment'],
                         valid_baselines=valid_baselines,
                         merge_baselines=merge_baselines,
                         valid_regions=valid_regions,
                         merge_regions=merge_regions)
    # NOTE(review): the validation set is built with the same augmentation
    # setting as the training set -- confirm that augmenting validation data
    # is intended.
    val_set = BaselineSet(evaluation_data,
                          line_width=hyper_params['line_width'],
                          im_transforms=transforms,
                          mode=format_type,
                          augmentation=hyper_params['augment'],
                          valid_baselines=valid_baselines,
                          merge_baselines=merge_baselines,
                          valid_regions=valid_regions,
                          merge_regions=merge_regions)

    if format_type is None:
        for page in training_data:
            gt_set.add(**page)
        for page in evaluation_data:
            val_set.add(**page)

    # overwrite class mapping in validation set
    val_set.num_classes = gt_set.num_classes
    val_set.class_mapping = gt_set.class_mapping

    if not load:
        spec = f'[{spec[1:-1]} O2l{gt_set.num_classes}]'
        message(f'Creating model {spec} with {gt_set.num_classes} outputs ', nl=False)
        nn = vgsl.TorchVGSLModel(spec)
        message('\u2713', fg='green')

    message('Training line types:')
    for k, v in gt_set.class_mapping['baselines'].items():
        message(f' {k}\t{v}')
    message('Training region types:')
    for k, v in gt_set.class_mapping['regions'].items():
        message(f' {k}\t{v}')

    if len(gt_set.imgs) == 0:
        logger.error('No valid training data was provided to the train command. Please add valid XML data.')
        return None

    if device == 'cpu':
        loader_threads = threads // 2
    else:
        loader_threads = threads

    train_loader = InfiniteDataLoader(gt_set,
                                      batch_size=1,
                                      shuffle=True,
                                      num_workers=loader_threads,
                                      pin_memory=True)
    val_loader = DataLoader(val_set,
                            batch_size=1,
                            shuffle=True,
                            num_workers=loader_threads,
                            pin_memory=True)
    # whatever is not used for data loading goes to the model, at least 1
    threads = max((threads - loader_threads, 1))

    # set model type metadata field and dump class_mapping
    nn.model_type = 'segmentation'
    nn.user_metadata['class_mapping'] = val_set.class_mapping

    # set mode to training
    nn.train()

    logger.debug(f'Set OpenMP threads to {threads}')
    nn.set_num_threads(threads)

    # the learning rate is driven by the scheduler, hence lr=0 here
    optim = getattr(torch.optim, hyper_params['optimizer'])(nn.nn.parameters(), lr=0)

    tr_it = TrainScheduler(optim)
    if hyper_params['schedule'] == '1cycle':
        add_1cycle(tr_it,
                   int(len(gt_set) * hyper_params['epochs']),
                   hyper_params['lrate'],
                   hyper_params['momentum'],
                   hyper_params['momentum'] - 0.10,
                   hyper_params['weight_decay'])
    elif hyper_params['schedule'] == 'exponential':
        add_exponential_decay(tr_it,
                              int(len(gt_set) * hyper_params['epochs']),
                              len(gt_set),
                              hyper_params['lrate'],
                              0.95,
                              hyper_params['momentum'],
                              hyper_params['weight_decay'])
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1,
                        2 * (hyper_params['lrate'],),
                        2 * (hyper_params['momentum'],),
                        hyper_params['weight_decay'],
                        annealing_const)

    if hyper_params['quit'] == 'early':
        st_it = EarlyStopping(hyper_params['min_delta'], hyper_params['lag'])
    elif hyper_params['quit'] == 'dumb':
        st_it = EpochStopping(hyper_params['epochs'] - hyper_params['completed_epochs'])
    else:
        # report the offending hyperparameter, not the `quit` builtin
        logger.error(f'Invalid training interruption scheme {hyper_params["quit"]}')
        return None

    trainer = cls(model=nn,
                  optimizer=optim,
                  device=device,
                  filename_prefix=output,
                  event_frequency=hyper_params['freq'],
                  train_set=train_loader,
                  val_set=val_loader,
                  stopper=st_it,
                  loss_fn=baseline_label_loss_fn,
                  evaluator=baseline_label_evaluator_fn)

    trainer.add_lr_scheduler(tr_it)

    return trainer
def recognition_train_gen(
        cls,
        hyper_params: Dict = default_specs.RECOGNITION_HYPER_PARAMS,
        progress_callback: Callable[[str, int], Callable[[None], None]] = lambda string, length: lambda: None,
        message: Callable[[str], None] = lambda *args, **kwargs: None,
        output: str = 'model',
        spec: str = default_specs.RECOGNITION_SPEC,
        append: Optional[int] = None,
        load: Optional[str] = None,
        device: str = 'cpu',
        reorder: bool = True,
        training_data: Sequence[Dict] = None,
        evaluation_data: Sequence[Dict] = None,
        preload: Optional[bool] = None,
        threads: int = 1,
        load_hyper_parameters: bool = False,
        repolygonize: bool = False,
        force_binarization: bool = False,
        format_type: str = 'path',
        codec: Optional[Dict] = None,
        resize: str = 'fail',
        augment: bool = False):
    """
    Builds a trainer for a recognition model from command line style
    arguments.

    This is an ugly constructor that takes all the arguments from the command
    line driver, finagles the datasets, models, and hyperparameters correctly
    and returns a trainer object created through `cls`.

    Setup parameters (load, training_data, evaluation_data, ....) are named,
    model hyperparameters (everything in
    kraken.lib.default_specs.RECOGNITION_HYPER_PARAMS) are in the
    `hyper_params` argument.

    Args:
        cls: Callable used to instantiate the trainer.
        hyper_params (dict): Hyperparameter dictionary containing all fields
                             from
                             kraken.lib.default_specs.RECOGNITION_HYPER_PARAMS.
                             The keys read here are `pad`, `normalization`,
                             `normalize_whitespace`, `augment`, `batch_size`,
                             `optimizer`, `lrate`, `momentum`, `schedule`,
                             `epochs`, `weight_decay`, `quit`, `min_delta`,
                             `lag`, `completed_epochs`, and `freq`. An
                             explicitly passed dictionary is updated in place
                             when `load_hyper_parameters` is set; the shared
                             default dictionary is never modified.
        progress_callback (Callable): Callback for progress reports on various
                                      computationally expensive processes. A
                                      human readable string and the process
                                      length is supplied. The callback has to
                                      return another function which will be
                                      executed after each successfully added
                                      line.
        message (Callable): Messaging printing method for above log but below
                            warning level output, i.e. infos that should
                            generally be shown to users.
        output (str): Passed to the trainer as `filename_prefix`.
        spec (str): Bracketed VGSL spec for a new model, or the layers to
                    append when `append` is set.
        append (int): Layer index handed to the loaded model's `append()`.
        load (str): Location of an existing model to continue training from.
        device (str): Passed to the trainer; also influences the number of
                      data loader workers.
        reorder (bool): Passed to the dataset class.
        training_data (Sequence): Training data. A list of image paths in
                                  `path` mode.
        evaluation_data (Sequence): Validation data; `None` is treated as an
                                    empty validation set.
        preload (bool): Dataset preloading. `None` enables it implicitly for
                        training sets of at most 2500 entries.
        threads (int): Thread budget split between data loading workers and
                       the model.
        load_hyper_parameters (bool): Merge the hyperparameters stored in the
                                      loaded model into `hyper_params`.
        repolygonize (bool): Passed to `preparse_xml_data`; ignored for `path`
                             and box line data.
        force_binarization (bool): Passed to the input transforms; ignored for
                                   `path` and box line data.
        format_type (str): `'path'`, `None`, or a format identifier understood
                           by `preparse_xml_data`.
        codec: Codec used to encode the training set.
        resize (str): Codec mismatch strategy for loaded models: `'fail'`,
                      `'add'`, or `'both'`.
        augment (bool): Accepted for interface compatibility; augmentation is
                        read from `hyper_params['augment']`.

    Returns:
        The trainer object, or None if no valid training data was found, the
        codec is incompatible with `resize='fail'`, or `resize`/`quit` have an
        invalid value (the reason is logged at error level).

    Raises:
        click.BadOptionUsage if the VGSL spec is malformed or the input
        transforms cannot be built from it.
    """
    # Never mutate the shared module-level defaults: a later
    # `hyper_params.update()` would otherwise leak the loaded model's
    # hyperparameters into every subsequent call relying on the default.
    if hyper_params is default_specs.RECOGNITION_HYPER_PARAMS:
        hyper_params = dict(hyper_params)

    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise the output size is still
    # unknown.
    nn = None

    if load:
        logger.info(f'Loading existing model from {load} ')
        message(f'Loading existing model from {load} ', nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        if load_hyper_parameters:
            hyper_params.update(nn.hyper_params)
            nn.hyper_params = hyper_params
        message('\u2713', fg='green', nl=False)

    DatasetClass = GroundTruthDataset
    valid_norm = True
    if format_type and format_type != 'path':
        logger.info(f'Parsing {len(training_data)} XML files for training data')
        if repolygonize:
            message('Repolygonizing data')
        training_data = preparse_xml_data(training_data, format_type, repolygonize)
        evaluation_data = preparse_xml_data(evaluation_data, format_type, repolygonize)
        DatasetClass = PolygonGTDataset
        valid_norm = False
    elif format_type == 'path':
        if force_binarization:
            logger.warning('Forced binarization enabled in `path` mode. Will be ignored.')
            force_binarization = False
        if repolygonize:
            logger.warning('Repolygonization enabled in `path` mode. Will be ignored.')
        training_data = [{'image': im} for im in training_data]
        if evaluation_data:
            evaluation_data = [{'image': im} for im in evaluation_data]
        valid_norm = True
    # format_type is None. Determine training type from length of training
    # data entry
    else:
        if len(training_data[0]) >= 4:
            DatasetClass = PolygonGTDataset
            valid_norm = False
        else:
            if force_binarization:
                logger.warning('Forced binarization enabled with box lines. Will be ignored.')
                force_binarization = False
            if repolygonize:
                logger.warning('Repolygonization enabled with box lines. Will be ignored.')

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        # startswith/endswith also reject an empty spec instead of raising
        # an IndexError.
        if not (spec.startswith('[') and spec.endswith(']')):
            raise click.BadOptionUsage('spec', f'VGSL spec {spec} not bracketed')
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage('spec', f'Invalid input spec {blocks[0]}')
        # the spec orders dimensions differently from `nn.input`
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch,
                                               height,
                                               width,
                                               channels,
                                               hyper_params['pad'],
                                               valid_norm,
                                               force_binarization)
    except KrakenInputException as e:
        raise click.BadOptionUsage('spec', str(e)) from e

    if len(training_data) > 2500 and not preload:
        logger.info('Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter')
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    # set multiprocessing tensor sharing strategy
    if 'file_system' in torch.multiprocessing.get_all_sharing_strategies():
        logger.debug('Setting multiprocessing tensor sharing strategy to file_system')
        torch.multiprocessing.set_sharing_strategy('file_system')

    gt_set = DatasetClass(normalization=hyper_params['normalization'],
                          whitespace_normalization=hyper_params['normalize_whitespace'],
                          reorder=reorder,
                          im_transforms=transforms,
                          preload=preload,
                          augmentation=hyper_params['augment'])
    bar = progress_callback('Building training set', len(training_data))
    for im in training_data:
        logger.debug(f'Adding line {im} to training set')
        try:
            gt_set.add(**im)
            bar()
        except FileNotFoundError as e:
            logger.warning(f'{e.strerror}: {e.filename}. Skipping.')
        except KrakenInputException as e:
            logger.warning(str(e))

    val_set = DatasetClass(normalization=hyper_params['normalization'],
                           whitespace_normalization=hyper_params['normalize_whitespace'],
                           reorder=reorder,
                           im_transforms=transforms,
                           preload=preload)
    # the parameter defaults to None; treat that as "no validation data"
    # instead of failing on len(None)
    if evaluation_data is None:
        evaluation_data = []
    bar = progress_callback('Building validation set', len(evaluation_data))
    for im in evaluation_data:
        logger.debug(f'Adding line {im} to validation set')
        try:
            val_set.add(**im)
            bar()
        except FileNotFoundError as e:
            logger.warning(f'{e.strerror}: {e.filename}. Skipping.')
        except KrakenInputException as e:
            logger.warning(str(e))

    if len(gt_set._images) == 0:
        logger.error('No valid training data was provided to the train command. Please add valid XML or line data.')
        return None

    logger.info(f'Training set {len(gt_set._images)} lines, validation set {len(val_set._images)} lines, alphabet {len(gt_set.alphabet)} symbols')
    alpha_diff_only_train = set(gt_set.alphabet).difference(set(val_set.alphabet))
    alpha_diff_only_val = set(val_set.alphabet).difference(set(gt_set.alphabet))
    if alpha_diff_only_train:
        logger.warning(f'alphabet mismatch: chars in training set only: {alpha_diff_only_train} (not included in accuracy test during training)')
    if alpha_diff_only_val:
        logger.warning(f'alphabet mismatch: chars in validation set only: {alpha_diff_only_val} (not trained)')
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(), key=lambda x: x[1], reverse=True):
        char = make_printable(k)
        if char == k:
            char = '\t' + char
        logger.info(f'{char}\t{v}')

    logger.debug('Encoding training set')

    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        logger.info(f'Appending {spec} to existing model {nn.spec} after {append}')
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info(f'Assembled model spec: {nn.spec}')
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException:
            message('Network codec not compatible with training set')
            alpha_diff = set(gt_set.alphabet).difference(set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error(f'Training data and model codec alphabets mismatch: {alpha_diff}')
                return None
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info(f'Resizing codec to include {len(alpha_diff)} new code points')
                codec.c2l.update({k: [v] for v, k in enumerate(alpha_diff, start=codec.max_label() + 1)})
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info(f'Resizing last layer in network to {codec.max_label()+1} outputs')
                nn.resize_output(codec.max_label() + 1)
                # the first encode attempt failed, so encode again with the
                # extended codec
                gt_set.encode(nn.codec)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ', nl=False)
                # log the number of code sequences, not the full alphabet
                logger.info(f'Resizing network or given codec to {len(gt_set.alphabet)} code sequences')
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info(f'Deleting {len(del_labels)} output classes from network ({len(codec)-len(del_labels)} retained)')
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label() + 1, del_labels)
                message('\u2713', fg='green')
            else:
                logger.error(f'invalid resize parameter value {resize}')
                return None
    else:
        gt_set.encode(codec)
        logger.info(f'Creating new model {spec} with {gt_set.codec.max_label()+1} outputs')
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    if nn.one_channel_mode and gt_set.im_mode != nn.one_channel_mode:
        logger.warning(f'Neural network has been trained on mode {nn.one_channel_mode} images, training set contains mode {gt_set.im_mode} data. Consider setting `force_binarization`')

    if format_type != 'path' and nn.seg_type == 'bbox':
        logger.warning('Neural network has been trained on bounding box image information but training set is polygonal.')

    # half the number of data loading processes if device isn't cuda and we
    # haven't enabled preloading
    if device == 'cpu' and not preload:
        loader_threads = threads // 2
    else:
        loader_threads = threads
    train_loader = InfiniteDataLoader(gt_set,
                                      batch_size=hyper_params['batch_size'],
                                      shuffle=True,
                                      num_workers=loader_threads,
                                      pin_memory=True,
                                      collate_fn=collate_sequences)
    # whatever is not used for data loading goes to the model, at least 1
    threads = max(threads - loader_threads, 1)

    # don't encode validation set as the alphabets may not match causing
    # encoding failures
    val_set.no_encode()
    val_loader = DataLoader(val_set,
                            batch_size=hyper_params['batch_size'],
                            num_workers=loader_threads,
                            pin_memory=True,
                            collate_fn=collate_sequences)

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(hyper_params['optimizer'],
                                                                           hyper_params['lrate'],
                                                                           hyper_params['momentum']))

    # set model type metadata field
    nn.model_type = 'recognition'

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug(f'Set OpenMP threads to {threads}')
    nn.set_num_threads(threads)

    # the learning rate is driven by the scheduler, hence lr=0 here
    optim = getattr(torch.optim, hyper_params['optimizer'])(nn.nn.parameters(), lr=0)

    if 'seg_type' not in nn.user_metadata:
        nn.user_metadata['seg_type'] = 'baselines' if format_type != 'path' else 'bbox'

    tr_it = TrainScheduler(optim)
    if hyper_params['schedule'] == '1cycle':
        add_1cycle(tr_it,
                   int(len(gt_set) * hyper_params['epochs']),
                   hyper_params['lrate'],
                   hyper_params['momentum'],
                   hyper_params['momentum'] - 0.10,
                   hyper_params['weight_decay'])
    elif hyper_params['schedule'] == 'exponential':
        add_exponential_decay(tr_it,
                              int(len(gt_set) * hyper_params['epochs']),
                              len(gt_set),
                              hyper_params['lrate'],
                              0.95,
                              hyper_params['momentum'],
                              hyper_params['weight_decay'])
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1,
                        2 * (hyper_params['lrate'],),
                        2 * (hyper_params['momentum'],),
                        hyper_params['weight_decay'],
                        annealing_const)

    if hyper_params['quit'] == 'early':
        st_it = EarlyStopping(hyper_params['min_delta'], hyper_params['lag'])
    elif hyper_params['quit'] == 'dumb':
        st_it = EpochStopping(hyper_params['epochs'] - hyper_params['completed_epochs'])
    else:
        # report the offending hyperparameter, not the `quit` builtin
        logger.error(f'Invalid training interruption scheme {hyper_params["quit"]}')
        return None

    trainer = cls(model=nn,
                  optimizer=optim,
                  device=device,
                  filename_prefix=output,
                  event_frequency=hyper_params['freq'],
                  train_set=train_loader,
                  val_set=val_loader,
                  stopper=st_it)

    trainer.add_lr_scheduler(tr_it)

    return trainer
def __init__(self,
             nets: Dict[str, TorchSeqRecognizer],
             im: Image.Image,
             bounds: dict,
             pad: int = 16,
             bidi_reordering: bool = True,
             script_ignore: Optional[List[str]] = None) -> None:
    """
    Sets up multi-model recognition of a script-annotated segmentation.

    Takes a dictionary of ISO15924 script identifiers->models and a
    script-annotated segmentation, validates that the models fit the
    segmentation and the image, and prepares one line transform per script.

    Args:
        nets (dict): A dict mapping ISO15924 identifiers to
                     TorchSeqRecognizer objects. Recommended to be a
                     defaultdict; for any other mapping every script in the
                     segmentation needs an entry.
        im (PIL.Image.Image): Image to extract text from
        bounds (dict): The segmentation. If it has a 'type' entry equal to
                       'baselines', the 'lines' entry is read and each line
                       must carry a 'script' key. Otherwise the 'boxes' entry
                       is read as a list of lines, each a list of
                       (script, ...) tuples.
        pad (int): Extra blank padding to the left and right of text line
        bidi_reordering (bool): Stored for later use by the recognizer.
        script_ignore (list): List of scripts to ignore during recognition;
                              stored as given.

    Raises:
        KrakenInputException if binary and non-binary models are mixed, or if
        the mapping between segmentation scripts and networks is incomplete
        and `nets` is not a defaultdict.
    """
    bounds_type = bounds.get('type')

    seg_types = set(recognizer.seg_type for recognizer in nets.values())
    if ('type' in bounds and bounds_type not in seg_types) or len(seg_types) > 1:
        logger.warning('Recognizers with segmentation types {} will be '
                       'applied to segmentation of type {}. This will likely result '
                       'in severely degraded performace'.format(seg_types, bounds_type))

    one_channel_modes = set(recognizer.nn.one_channel_mode for recognizer in nets.values())
    if '1' in one_channel_modes and len(one_channel_modes) > 1:
        raise KrakenInputException('Mixing binary and non-binary recognition models is not supported.')
    elif '1' in one_channel_modes and not is_bitonal(im):
        logger.warning('Running binary models on non-binary input image '
                       '(mode {}). This will result in severely degraded '
                       'performance'.format(im.mode))

    if 'type' in bounds and bounds_type == 'baselines':
        valid_norm = False
        self.len = len(bounds['lines'])
        self.seg_key = 'lines'
        self.next_iter = self._recognize_baseline_line
        self.line_iter = iter(bounds['lines'])
        scripts = [x['script'] for x in bounds['lines']]
    else:
        valid_norm = True
        self.len = len(bounds['boxes'])
        self.seg_key = 'boxes'
        self.next_iter = self._recognize_box_line
        self.line_iter = iter(bounds['boxes'])
        # a box line consists of several (script, ...) runs
        scripts = [x[0] for line in bounds['boxes'] for x in line]

    im_str = get_im_str(im)
    logger.info('Running {} multi-script recognizers on {} with {} lines'.format(len(nets), im_str, self.len))

    # `.get()` does not trigger a defaultdict's factory, so `miss` may be
    # non-empty for defaultdicts; those are allowed to fill gaps on access.
    miss = [script for script in scripts if not nets.get(script)]
    if miss and not isinstance(nets, defaultdict):
        raise KrakenInputException('Missing models for scripts {}'.format(set(miss)))

    # build dictionary for line preprocessing. Each distinct script needs its
    # transforms only once; dict.fromkeys() deduplicates while keeping the
    # order of first appearance.
    self.ts = {}
    for script in dict.fromkeys(scripts):
        logger.debug('Loading line transforms for {}'.format(script))
        network = nets[script]
        batch, channels, height, width = network.nn.input
        self.ts[script] = generate_input_transforms(batch, height, width, channels, pad, valid_norm)

    self.im = im
    self.nets = nets
    self.bidi_reordering = bidi_reordering
    self.pad = pad
    self.bounds = bounds
    self.script_ignore = script_ignore
def train(ctx, pad, output, spec, append, load, savefreq, report, quit, epochs,
          lag, min_delta, device, optimizer, lrate, momentum, weight_decay,
          schedule, partition, normalization, codec, resize, reorder,
          training_files, evaluation_files, preload, threads, ground_truth):
    """
    Trains a model from image-text pairs.
    """
    # NOTE(review): the docstring is kept to one line on purpose -- this
    # function presumably is a click command and click renders the docstring
    # as help text. Parameter overview, as used below:
    #
    #   pad                 padding passed to the line input transforms
    #   output              prefix of the saved model files
    #                       (`{output}_{epoch}.mlmodel`, `{output}_best.mlmodel`)
    #   spec                bracketed VGSL spec for a new model, or the layers
    #                       to append when `append` is set
    #   append              layer index handed to the loaded model's `append()`
    #   load                existing model to continue training from
    #   savefreq/report     the model is saved/evaluated whenever the epoch is
    #                       a multiple of the respective value
    #   quit                'early' (EarlyStopping with `min_delta`, `lag`) or
    #                       'dumb' (EpochStopping with `epochs`)
    #   device              device the recognizer and the batches are moved to
    #   optimizer           name of a class in `torch.optim`
    #   lrate/momentum/weight_decay/schedule
    #                       learning rate schedule ('1cycle' or constant)
    #   partition           fraction of the ground truth used for training when
    #                       no explicit evaluation files are given
    #   normalization/reorder/preload
    #                       passed to the datasets
    #   codec/resize        training set codec and the codec mismatch strategy
    #                       for loaded models ('fail', 'add', 'both')
    #   training_files/ground_truth
    #                       line images to train on
    #   evaluation_files    explicit validation set
    #   threads             number of OpenMP threads set on the model
    if not load and append:
        raise click.BadOptionUsage('append', 'append option requires loading an existing model')

    if resize != 'fail' and not load:
        raise click.BadOptionUsage('resize', 'resize option requires loading an existing model')

    # imports are local to the function, as in the original
    import re
    import torch
    import shutil
    import numpy as np

    from torch.utils.data import DataLoader

    from kraken.lib import models, vgsl, train
    from kraken.lib.util import make_printable
    from kraken.lib.train import EarlyStopping, EpochStopping, TrainStopper, TrainScheduler, add_1cycle
    from kraken.lib.codec import PytorchCodec
    from kraken.lib.dataset import GroundTruthDataset, compute_error, generate_input_transforms

    logger.info('Building ground truth set from {} line images'.format(len(ground_truth) + len(training_files)))

    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise the output size is still
    # unknown.
    nn = None

    if load:
        logger.info('Loading existing model from {} '.format(load))
        message('Loading model {}'.format(load), nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        message('\u2713', fg='green', nl=False)

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        # startswith/endswith also reject an empty spec instead of raising
        # an IndexError.
        if not (spec.startswith('[') and spec.endswith(']')):
            raise click.BadOptionUsage('spec', 'VGSL spec {} not bracketed'.format(spec))
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage('spec', 'Invalid input spec {}'.format(blocks[0]))
        # the spec orders dimensions differently from `nn.input`
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch, height, width, channels, pad)
    except KrakenInputException as e:
        raise click.BadOptionUsage('spec', str(e)) from e

    # disable automatic partition when given evaluation set explicitly
    if evaluation_files:
        partition = 1
    ground_truth = list(ground_truth)

    # merge training_files into ground_truth list
    if training_files:
        ground_truth.extend(training_files)

    if len(ground_truth) == 0:
        raise click.UsageError('No training data was provided to the train command. Use `-t` or the `ground_truth` argument.')

    np.random.shuffle(ground_truth)

    if len(ground_truth) > 2500 and not preload:
        logger.info('Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter')
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    split_at = int(len(ground_truth) * partition)
    tr_im = ground_truth[:split_at]
    if evaluation_files:
        logger.debug('Using {} lines from explicit eval set'.format(len(evaluation_files)))
        te_im = evaluation_files
    else:
        te_im = ground_truth[split_at:]
        logger.debug('Taking {} lines from training for evaluation'.format(len(te_im)))

    gt_set = GroundTruthDataset(normalization=normalization,
                                reorder=reorder,
                                im_transforms=transforms,
                                preload=preload)
    with log.progressbar(tr_im, label='Building training set') as bar:
        for im in bar:
            logger.debug('Adding line {} to training set'.format(im))
            try:
                gt_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    val_set = GroundTruthDataset(normalization=normalization,
                                 reorder=reorder,
                                 im_transforms=transforms,
                                 preload=preload)
    with log.progressbar(te_im, label='Building validation set') as bar:
        for im in bar:
            logger.debug('Adding line {} to validation set'.format(im))
            try:
                val_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    logger.info('Training set {} lines, validation set {} lines, alphabet {} symbols'.format(len(gt_set._images),
                                                                                            len(val_set._images),
                                                                                            len(gt_set.alphabet)))
    alpha_diff = set(gt_set.alphabet).symmetric_difference(set(val_set.alphabet))
    if alpha_diff:
        # Logger.warn() is a deprecated alias of Logger.warning()
        logger.warning('alphabet mismatch {}'.format(alpha_diff))
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(), key=lambda x: x[1], reverse=True):
        char = make_printable(k)
        if char == k:
            char = '\t' + char
        logger.info(u'{}\t{}'.format(char, v))

    logger.debug('Encoding training set')

    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        logger.info('Appending {} to existing model {} after {}'.format(spec, nn.spec, append))
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info('Assembled model spec: {}'.format(nn.spec))
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException:
            message('Network codec not compatible with training set')
            alpha_diff = set(gt_set.alphabet).difference(set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error('Training data and model codec alphabets mismatch: {}'.format(alpha_diff))
                ctx.exit(code=1)
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info('Resizing codec to include {} new code points'.format(len(alpha_diff)))
                codec.c2l.update({k: [v] for v, k in enumerate(alpha_diff, start=codec.max_label() + 1)})
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info('Resizing last layer in network to {} outputs'.format(codec.max_label() + 1))
                nn.resize_output(codec.max_label() + 1)
                # the first encode attempt above failed, so the training set
                # must be encoded again with the extended codec
                gt_set.encode(nn.codec)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ', nl=False)
                logger.info('Resizing network or given codec to {} code sequences'.format(len(gt_set.alphabet)))
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info('Deleting {} output classes from network ({} retained)'.format(len(del_labels),
                                                                                           len(codec) - len(del_labels)))
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label() + 1, del_labels)
                message('\u2713', fg='green')
            else:
                raise click.BadOptionUsage('resize', 'Invalid resize value {}'.format(resize))
    else:
        gt_set.encode(codec)
        logger.info('Creating new model {} with {} outputs'.format(spec, gt_set.codec.max_label() + 1))
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    train_loader = DataLoader(gt_set, batch_size=1, shuffle=True, pin_memory=True)

    # don't encode validation set as the alphabets may not match causing
    # encoding failures
    val_set.training_set = list(zip(val_set._images, val_set._gt))

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(optimizer, lrate, momentum))

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    nn.set_num_threads(threads)

    logger.debug('Moving model to device {}'.format(device))
    rec = models.TorchSeqRecognizer(nn, train=True, device=device)
    # the learning rate is driven by the scheduler, hence lr=0 here
    optim = getattr(torch.optim, optimizer)(nn.nn.parameters(), lr=0)

    tr_it = TrainScheduler(optim)
    if schedule == '1cycle':
        add_1cycle(tr_it, epochs * len(gt_set), lrate, momentum, momentum - 0.10, weight_decay)
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1, (lrate, lrate), (momentum, momentum), weight_decay, train.annealing_const)

    st_it = cast(TrainStopper, None)  # type: TrainStopper
    if quit == 'early':
        st_it = EarlyStopping(train_loader, min_delta, lag)
    elif quit == 'dumb':
        st_it = EpochStopping(train_loader, epochs)
    else:
        raise click.BadOptionUsage('quit', 'Invalid training interruption scheme {}'.format(quit))

    for epoch, loader in enumerate(st_it):
        with log.progressbar(label='epoch {}/{}'.format(epoch, epochs - 1 if epochs > 0 else '∞'),
                             length=len(loader),
                             show_pos=True) as bar:
            for trial, (line, target) in enumerate(loader):
                tr_it.step()
                line = line.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)
                line = line.requires_grad_()
                o = nn.nn(line)
                # height should be 1 by now
                if o.size(2) != 1:
                    raise KrakenInputException('Expected dimension 3 to be 1, actual {}'.format(o.size(2)))
                o = o.squeeze(2)
                optim.zero_grad()
                # NCW -> WNC
                loss = nn.criterion(o.permute(2, 0, 1),  # type: ignore
                                    target,
                                    (o.size(2),),
                                    (target.size(1),))
                logger.info('trial {}'.format(trial))
                # skip the update for infinite losses
                if not torch.isinf(loss):
                    loss.backward()
                    optim.step()
                else:
                    logger.debug('infinite loss in trial {}'.format(trial))
                bar.update(1)
        if not epoch % savefreq:
            logger.info('Saving to {}_{}'.format(output, epoch))
            try:
                nn.save_model('{}_{}.mlmodel'.format(output, epoch))
            except Exception as e:
                # best effort: a failed checkpoint must not abort training
                logger.error('Saving model failed: {}'.format(str(e)))
        if not epoch % report:
            logger.debug('Starting evaluation run')
            nn.eval()
            chars, error = compute_error(rec, list(val_set))
            nn.train()
            # an empty validation set yields no characters; report zero
            # accuracy instead of dividing by zero
            accuracy = (chars - error) / chars if chars else 0.0
            logger.info('Accuracy report ({}) {:0.4f} {} {}'.format(epoch, accuracy, chars, error))
            message('Accuracy report ({}) {:0.4f} {} {}'.format(epoch, accuracy, chars, error))
            st_it.update(accuracy)

    if quit == 'early':
        message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format(output,
                                                                                    st_it.best_epoch,
                                                                                    st_it.best_loss))
        logger.info('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format(output,
                                                                                        st_it.best_epoch,
                                                                                        st_it.best_loss))
        shutil.copy('{}_{}.mlmodel'.format(output, st_it.best_epoch), '{}_best.mlmodel'.format(output))
def segment(im,
            text_direction: str = 'horizontal-lr',
            mask: Optional[np.array] = None,
            reading_order_fn: Callable = polygonal_reading_order,
            model=None,
            device: str = 'cpu'):
    """
    Segments a page into text lines using the baseline segmenter.

    Runs the segmentation network on the image, vectorizes the resulting
    heatmaps into baselines and regions, computes a polygonal boundary for
    each baseline, scales everything back to the coordinate system of the
    input image, and sorts the lines with `reading_order_fn`.

    Args:
        im (PIL.Image): The page image to segment.
        text_direction (str): Not used by the network itself. Its last two
                              characters (`lr` or `rl`) are handed to
                              `reading_order_fn` and the full value is
                              echoed in the result for serialization.
        mask (PIL.Image): A bi-level mask image of the same size as `im`.
                          In this function the mask is only validated and
                          converted to an array; it is not applied to the
                          network output.
        reading_order_fn (function): Function to determine the reading
                                     order. It is called with the keyword
                                     arguments `lines` (a list of
                                     (type, baseline, polygon) tuples),
                                     `regions` (a flat list of all regions)
                                     and `text_direction` (`lr` or `rl`).
        model (vgsl.TorchVGSLModel): A TorchVGSLModel containing a
                                     segmentation model. If none is given
                                     the default model shipped with the
                                     package is loaded.
        device (str or torch.Device): The target device to run the neural
                                      network on.

    Returns:
        {'text_direction': '$dir',
         'type': 'baselines',
         'lines': [
            {'script': '$type',
             'baseline': [[x0, y0], [x1, y1], ..., [x_n, y_n]],
             'boundary': [[x0, y0], [x1, y1], ..., [x_m, y_m]]},
            ...
         ],
         'regions': {'$region_type': [...], ...},
         'script_detection': True/False
        }: A dictionary containing the text direction, under the key
        'lines' the baselines in the order returned by `reading_order_fn`
        together with their polygonal boundaries, and under the key
        'regions' the scaled regions grouped by region type.
        'script_detection' is True when the model's class mapping defines
        more than one baseline type.

    Raises:
        KrakenInputException if a mask is given that is not bitonal or
        whose size differs from the size of the input image.
    """
    im_str = get_im_str(im)
    logger.info(f'Segmenting {im_str}')

    if model is None:
        logger.info('No segmentation model given. Loading default model.')
        model = vgsl.TorchVGSLModel.load_model(pkg_resources.resource_filename(__name__, 'blla.mlmodel'))

    if model.one_channel_mode == '1' and not is_bitonal(im):
        logger.warning('Running binary model on non-binary input image '
                       '(mode {}). This will result in severely degraded '
                       'performance'.format(im.mode))

    model.eval()
    model.to(device)

    # explicit None test: relying on the truthiness of an image/array object
    # is fragile (numpy arrays raise on ambiguous truth values)
    if mask is not None:
        if mask.mode != '1' and not is_bitonal(mask):
            logger.error('Mask is not bitonal')
            raise KrakenInputException('Mask is not bitonal')
        mask = mask.convert('1')
        if mask.size != im.size:
            # f-string so that the actual sizes end up in the message
            msg = f'Mask size {mask.size} doesn\'t match image size {im.size}'
            logger.error(msg)
            raise KrakenInputException(msg)
        logger.info('Masking enabled in segmenter.')
        # NOTE(review): the converted mask is not read again below - confirm
        # whether it is supposed to be applied during postprocessing.
        mask = pil2array(mask)

    batch, channels, height, width = model.input
    transforms = dataset.generate_input_transforms(batch, height, width, channels, 0, valid_norm=False)
    # the first three transforms yield the rescaled image the network sees;
    # it is needed later as reference for the polygonization step
    res_tf = tf.Compose(transforms.transforms[:3])
    scal_im = res_tf(im).convert('L')

    with torch.no_grad():
        logger.debug('Running network forward pass')
        o = model.nn(transforms(im).unsqueeze(0).to(device))
    logger.debug('Upsampling network output')
    o = F.interpolate(o, size=scal_im.size[::-1])
    o = o.squeeze().cpu().numpy()
    # (x, y) factors mapping heatmap coordinates back onto the input image
    scale = np.divide(im.size, o.shape[:0:-1])

    # postprocessing
    cls_map = model.user_metadata['class_mapping']
    st_sep = cls_map['aux']['_start_separator']
    end_sep = cls_map['aux']['_end_separator']

    logger.info('Vectorizing baselines')
    baselines = []
    for bl_type, idx in cls_map['baselines'].items():
        logger.debug(f'Vectorizing lines of type {bl_type}')
        baselines.extend([(bl_type, x) for x in vectorize_lines(o[(st_sep, end_sep, idx), :, :])])

    logger.info('Vectorizing regions')
    regions = {}
    for region_type, idx in cls_map['regions'].items():
        logger.debug(f'Vectorizing regions of type {region_type}')
        regions[region_type] = vectorize_regions(o[idx])

    logger.debug('Polygonizing lines')
    bl_types = [x[0] for x in baselines]
    bl_points = [x[1] for x in baselines]
    polygons = calculate_polygonal_environment(scal_im, bl_points)
    # drop baselines for which no polygon could be computed
    lines = [x for x in zip(bl_types, bl_points, polygons) if x[2] is not None]

    logger.debug('Scaling vectorized lines')
    sc = scale_polygonal_lines([x[1:] for x in lines], scale)
    lines = list(zip([x[0] for x in lines], [x[0] for x in sc], [x[1] for x in sc]))

    logger.debug('Scaling vectorized regions')
    regions = {reg_id: scale_regions(regs, scale) for reg_id, regs in regions.items()}

    logger.debug('Reordering baselines')
    order_regs = [reg for regs in regions.values() for reg in regs]
    lines = reading_order_fn(lines=lines, regions=order_regs, text_direction=text_direction[-2:])

    # more than one baseline class means lines carry meaningful type labels
    script_detection = len(cls_map['baselines']) > 1

    return {'text_direction': text_direction,
            'type': 'baselines',
            'lines': [{'script': bl_type, 'baseline': bl, 'boundary': pl} for bl_type, bl, pl in lines],
            'regions': regions,
            'script_detection': script_detection}
"""
Produces semi-transparent neural segmenter output overlays

Usage: the first command line argument is the segmentation model, all
following arguments are image files. For each image `<stem>.heat.png` and
`<stem>.overlay.png` are written next to it.
"""
import sys
import torch
from PIL import Image
from kraken.lib import vgsl, dataset
import torch.nn.functional as F
from os.path import splitext

model = vgsl.TorchVGSLModel.load_model(sys.argv[1])
model.eval()
batch, channels, height, width = model.input

transforms = dataset.generate_input_transforms(batch, height, width, channels, 0, valid_norm=False)

imgs = sys.argv[2:]
torch.set_num_threads(1)

for img in imgs:
    print(img)
    stem = splitext(img)[0]
    # context manager so the underlying file handle is released after each
    # image instead of staying open until garbage collection
    with Image.open(img) as im:
        with torch.no_grad():
            o = model.nn(transforms(im).unsqueeze(0))
            # bring the network output back to the size of the input image
            o = F.interpolate(o, size=im.size[::-1])
            o = o.squeeze().numpy()
        # output channel 1 rendered as an 8 bit image; it doubles as the
        # blending mask for the overlay below
        heat = Image.fromarray((o[1]*255).astype('uint8'))
        heat.save(stem + '.heat.png')
        overlay = Image.new('RGBA', im.size, (0, 130, 200, 255))
        Image.composite(overlay, im.convert('RGBA'), heat).save(stem + '.overlay.png')
def rpred(network: TorchSeqRecognizer,
          im: Image.Image,
          bounds: dict,
          pad: int = 16,
          bidi_reordering: bool = True) -> Generator[ocr_record, None, None]:
    """
    Uses a RNN to recognize text

    Exactly one record is yielded per box produced by `extract_boxes`. Boxes
    that have a zero dimension, cannot be converted into a network input, or
    are completely uniform result in an empty record.

    Args:
        network (kraken.lib.models.TorchSeqRecognizer): A TorchSeqRecognizer
                                                        object
        im (PIL.Image.Image): Image to extract text from
        bounds (dict): A dictionary containing a 'boxes' entry with a list of
                       coordinates (x0, y0, x1, y1) of a text line in the image
                       and an entry 'text_direction' containing
                       'horizontal-lr/rl/vertical-lr/rl'.
        pad (int): Extra blank padding to the left and right of text line.
                   Auto-disabled when expected network inputs are incompatible
                   with padding.
        bidi_reordering (bool): Reorder classes in the ocr_record according to
                                the Unicode bidirectional algorithm for correct
                                display.
    Yields:
        An ocr_record containing the recognized text, absolute character
        positions, and confidence values for each character.
    """
    im_str = get_im_str(im)
    logger.info('Running recognizer on {} with {} lines'.format(im_str, len(bounds['boxes'])))
    logger.debug('Loading line transform')
    batch, channels, height, width = network.nn.input
    ts = generate_input_transforms(batch, height, width, channels, pad)

    for box, coords in extract_boxes(im, bounds):
        # check if boxes are non-zero in any dimension (width or height)
        if coords[2] - coords[0] == 0 or coords[3] - coords[1] == 0:
            logger.warning('bbox {} with zero dimension. Emitting empty record.'.format(coords))
            yield ocr_record('', [], [])
            continue
        # try conversion into tensor; deliberately best-effort so a single
        # bad line does not abort recognition of the whole page
        try:
            line = ts(box)
        except Exception:
            logger.warning('Conversion of line {} failed. Emitting empty record.'.format(coords))
            yield ocr_record('', [], [])
            continue
        # check if line is non-zero
        if line.max() == line.min():
            logger.warning('Empty line {}. Emitting empty record.'.format(coords))
            yield ocr_record('', [], [])
            continue

        preds = network.predict(line)
        # calculate recognized LSTM locations of characters
        # scale between network output and network input
        net_scale = line.shape[2]/network.outputs.shape[1]
        # scale between network input and original line
        in_scale = box.size[0]/(line.shape[2]-2*pad)

        def _scale_val(val, min_val, max_val):
            # network output position -> offset in the original line, clamped
            return int(round(min(max(((val*net_scale)-pad)*in_scale, min_val), max_val)))

        # XXX: fix bounding box calculation ocr_record for multi-codepoint labels.
        pred = ''.join(x[0] for x in preds)
        pos = []
        conf = []
        for _, start, end, c in preds:
            if bounds['text_direction'].startswith('horizontal'):
                xmin = coords[0] + _scale_val(start, 0, box.size[0])
                xmax = coords[0] + _scale_val(end, 0, box.size[0])
                pos.append((xmin, coords[1], xmax, coords[3]))
            else:
                ymin = coords[1] + _scale_val(start, 0, box.size[1])
                # the lower edge is derived from `end`; using `start` here
                # would collapse every cut to zero height
                ymax = coords[1] + _scale_val(end, 0, box.size[1])
                pos.append((coords[0], ymin, coords[2], ymax))
            conf.append(c)
        if bidi_reordering:
            logger.debug('BiDi reordering record.')
            yield bidi_record(ocr_record(pred, pos, conf))
        else:
            logger.debug('Emitting raw record')
            yield ocr_record(pred, pos, conf)
def mm_rpred(nets: Dict[str, TorchSeqRecognizer],
             im: Image.Image,
             bounds: dict,
             pad: int = 16,
             bidi_reordering: bool = True,
             script_ignore: Optional[List[str]] = None) -> Generator[ocr_record, None, None]:
    """
    Multi-model version of kraken.rpred.rpred.

    Takes a dictionary of ISO15924 script identifiers->models and an
    script-annotated segmentation to dynamically select appropriate models for
    these lines.

    Each line of the segmentation consists of one or more runs, each run being
    a (script, bounding box) pair. All runs of a line are recognized with the
    model of their script and merged into a single record, so exactly one
    record is yielded per line. Runs that are ignored, have a zero dimension,
    cannot be converted into a network input, or are completely uniform
    contribute nothing to the record of their line.

    Args:
        nets (dict): A dict mapping ISO15924 identifiers to TorchSeqRecognizer
                     objects. Models are looked up with `nets.get()` and the
                     line transforms are built from `nets.items()`, so default
                     factories of a defaultdict are not triggered.
        im (PIL.Image.Image): Image to extract text from
        bounds (dict): A dictionary containing a 'boxes' entry with a list of
                       lists of coordinates (script, (x0, y0, x1, y1)) of a
                       text line in the image and an entry 'text_direction'
                       containing 'horizontal-lr/rl/vertical-lr/rl'.
        pad (int): Extra blank padding to the left and right of text line
        bidi_reordering (bool): Reorder classes in the ocr_record according to
                                the Unicode bidirectional algorithm for correct
                                display.
        script_ignore (list): List of scripts to ignore during recognition.
                              No model is required for these scripts.
    Yields:
        An ocr_record containing the recognized text, absolute character
        positions, and confidence values for each character.

    Raises:
        KrakenInputException if the mapping between segmentation scripts and
        networks is incomplete.
    """
    im_str = get_im_str(im)
    logger.info('Running {} multi-script recognizers on {} with {} lines'.format(len(nets), im_str, len(bounds['boxes'])))

    # Collect the scripts of all runs that are going to be recognized but have
    # no model. `bounds['boxes']` is a list of lines and each line a list of
    # (script, box) runs, so the script has to be taken from the runs and not
    # from the lines.
    miss = []
    for runs in bounds['boxes']:
        for run in runs:
            script = run[0]
            if script_ignore is not None and script in script_ignore:
                continue
            if not nets.get(script) and script not in miss:
                miss.append(script)
    if miss:
        raise KrakenInputException('Missing models for scripts {}'.format(miss))

    # build dictionary for line preprocessing
    ts = {}
    for script, network in nets.items():
        logger.debug('Loading line transforms for {}'.format(script))
        batch, channels, height, width = network.nn.input
        ts[script] = generate_input_transforms(batch, height, width, channels, pad)

    for runs in bounds['boxes']:
        rec = ocr_record('', [], [])
        run_scripts = [run[0] for run in runs]
        run_bounds = {'text_direction': bounds['text_direction'],
                      'boxes': [run[1] for run in runs]}
        for script, (box, coords) in zip(run_scripts, extract_boxes(im, run_bounds)):
            # skip if script is set to ignore
            if script_ignore is not None and script in script_ignore:
                logger.info('Ignoring {} line segment.'.format(script))
                continue
            # check if boxes are non-zero in any dimension (width or height)
            if coords[2] - coords[0] == 0 or coords[3] - coords[1] == 0:
                logger.warning('Run with zero dimension. Skipping.')
                continue
            # try conversion into tensor; deliberately best-effort so a single
            # bad run does not abort recognition of the whole page
            try:
                logger.debug('Preparing run.')
                tensor = ts[script](box)
            except Exception:
                logger.warning('Conversion of line {} failed. Skipping.'.format(coords))
                continue
            # check if line is non-zero
            if tensor.max() == tensor.min():
                logger.warning('Empty run. Skipping.')
                continue

            logger.debug('Forward pass with model {}'.format(script))
            preds = nets[script].predict(tensor)

            # calculate recognized LSTM locations of characters
            logger.debug('Convert to absolute coordinates')
            # scale between network output and network input; the output
            # sequence is not necessarily as long as the input is wide
            net_scale = tensor.shape[2]/nets[script].outputs.shape[1]
            # scale between network input and original run
            in_scale = box.size[0]/(tensor.shape[2]-2*pad)

            def _offset(val):
                # network output position -> offset inside the original run
                return ((val*net_scale)-pad)*in_scale

            pred = ''.join(x[0] for x in preds)
            pos = []
            conf = []

            for _, start, end, c in preds:
                if bounds['text_direction'].startswith('horizontal'):
                    xmin = coords[0] + int(max(_offset(start), 0))
                    xmax = coords[0] + max(int(min(_offset(end), coords[2]-coords[0])), 1)
                    pos.append((xmin, coords[1], xmax, coords[3]))
                else:
                    ymin = coords[1] + int(max(_offset(start), 0))
                    ymax = coords[1] + max(int(min(_offset(end), coords[3]-coords[1])), 1)
                    pos.append((coords[0], ymin, coords[2], ymax))
                conf.append(c)
            rec.prediction += pred
            rec.cuts.extend(pos)
            rec.confidences.extend(conf)
        if bidi_reordering:
            logger.debug('BiDi reordering record.')
            yield bidi_record(rec)
        else:
            logger.debug('Emitting raw record')
            yield rec