def render_report(model: str,
                  chars: int,
                  errors: int,
                  char_confusions: Counter,
                  scripts: Counter,
                  insertions: Counter,
                  deletions: int,
                  substitutions: Counter) -> str:
    """
    Assembles the statistics of a test run and renders them through the
    `report` jinja template of the `kraken` package.

    Args:
        model (str): Model name.
        chars (int): Total number of characters; denominator of the overall
                     accuracy.
        errors (int): Number of errors on test set.
        char_confusions (dict): Dictionary mapping a tuple (gt, pred) to a
                                number of occurrences. Pairs with identical
                                members are left out of the report.
        scripts (dict): Dictionary counting character per script.
        insertions (dict): Dictionary counting insertion operations per
                           Unicode script.
        deletions (int): Number of deletions.
        substitutions (dict): Dictionary counting substitution operations per
                              Unicode script.

    Returns:
        A string containing the rendered report.

    Raises:
        ZeroDivisionError: if `chars` or one of the per-script counts is 0.
    """
    logger.info(f'Serializing report for {model}.')

    overall_accuracy = (chars-errors)/chars * 100

    # one row per script, best accuracy first
    script_rows = []
    for script, count in scripts.items():
        script_errors = insertions[script] + substitutions[script]
        script_rows.append({'script': script,
                            'count': count,
                            'errors': script_errors,
                            'accuracy': 100 * (count-script_errors)/count})
    script_rows.sort(key=lambda row: row['accuracy'], reverse=True)

    # one row per actual confusion (gt != pred), most frequent first
    confusion_rows = []
    for pair, occurrences in char_confusions.items():
        if pair[0] == pair[1]:
            continue
        confusion_rows.append({'correct': make_printable(pair[0]),
                               'generated': make_printable(pair[1]),
                               'errors': occurrences})
    confusion_rows.sort(key=lambda row: row['errors'], reverse=True)

    report = {'model': model,
              'chars': chars,
              'errors': errors,
              'accuracy': overall_accuracy,
              'insertions': sum(insertions.values()),
              'deletions': deletions,
              'substitutions': sum(substitutions.values()),
              'scripts': script_rows,
              'counts': confusion_rows}

    logger.debug('Initializing jinja environment.')
    env = Environment(loader=PackageLoader('kraken', 'templates'),
                      trim_blocks=True,
                      lstrip_blocks=True,
                      autoescape=True)
    logger.debug('Retrieving template.')
    template = env.get_template('report')
    logger.debug('Rendering data.')
    return template.render(report=report)
def show(ctx, model_id):
    """
    Retrieves model metadata from the repository.

    The description of `model_id` is fetched with `repo.get_description`,
    printed through `message`, and the click context is exited with code 0.
    Graphemes that are not printable are listed separately in their escaped
    (`make_printable`) form.
    """
    import unicodedata

    from kraken import repo
    from kraken.lib.util import make_printable, is_printable

    desc = repo.get_description(model_id)

    # split the sorted alphabet into directly printable symbols and escaped ones
    printable = []
    escaped = []
    for grapheme in sorted(desc['graphemes']):
        if is_printable(grapheme):
            printable.append(grapheme)
        else:
            escaped.append(make_printable(grapheme))

    template = 'name: {}\n\n{}\n\n{}\nscripts: {}\nalphabet: {} {}\nlicense: {}\nauthor: {} ({})\n{}'
    fields = (desc['name'],
              desc['summary'],
              desc['description'],
              ' '.join(desc['script']),
              ''.join(printable),
              ', '.join(escaped),
              desc['license'],
              desc['author'],
              desc['author-email'],
              desc['url'])
    message(template.format(*fields))
    ctx.exit(0)
def show(ctx, model_id):
    """
    Retrieves model metadata from the repository.

    The description of `model_id` is fetched with `repo.get_description`,
    printed through `message` (name, summary, description, scripts, alphabet,
    accuracy, license id, creators and publication date), and the click
    context is exited with code 0. Graphemes that are not printable are
    listed separately in their escaped (`make_printable`) form.
    """
    import unicodedata

    from kraken import repo
    from kraken.lib.util import make_printable, is_printable

    desc = repo.get_description(model_id)

    # split the sorted alphabet into directly printable symbols and escaped ones
    printable = []
    escaped = []
    for grapheme in sorted(desc['graphemes']):
        if is_printable(grapheme):
            printable.append(grapheme)
        else:
            escaped.append(make_printable(grapheme))

    template = 'name: {}\n\n{}\n\n{}\nscripts: {}\nalphabet: {} {}\naccuracy: {:.2f}%\nlicense: {}\nauthor(s): {}\ndate: {}'
    fields = (model_id,
              desc['summary'],
              desc['description'],
              ' '.join(desc['script']),
              ''.join(printable),
              ', '.join(escaped),
              desc['accuracy'],
              desc['license']['id'],
              '; '.join(creator['name'] for creator in desc['creators']),
              desc['publication_date'])
    message(template.format(*fields))
    ctx.exit(0)
def show(ctx, model_id):
    """
    Retrieves model metadata from the repository.

    Looks up the description of `model_id` via `repo.get_description`, emits
    a summary through `message` and exits the click context with code 0.
    Non-printable graphemes are shown in their `make_printable` form after
    the printable part of the alphabet.
    """
    import unicodedata

    from kraken import repo
    from kraken.lib.util import make_printable, is_printable

    desc = repo.get_description(model_id)

    symbols = []
    escaped_symbols = []
    for grapheme in sorted(desc['graphemes']):
        if is_printable(grapheme):
            symbols.append(grapheme)
        else:
            escaped_symbols.append(make_printable(grapheme))

    scripts = ' '.join(desc['script'])
    alphabet = ''.join(symbols)
    escaped = ', '.join(escaped_symbols)
    accuracy = desc['accuracy']
    license_id = desc['license']['id']
    authors = '; '.join(creator['name'] for creator in desc['creators'])
    published = desc['publication_date']

    message('name: {}\n\n{}\n\n{}\nscripts: {}\nalphabet: {} {}\naccuracy: {:.2f}%\nlicense: {}\nauthor(s): {}\ndate: {}'.format(
        model_id,
        desc['summary'],
        desc['description'],
        scripts,
        alphabet,
        escaped,
        accuracy,
        license_id,
        authors,
        published))
    ctx.exit(0)
def line_generator(ctx, font, maxlines, encoding, normalization, renormalize,
                   reorder, font_size, font_weight, language, max_length,
                   strip, disable_degradation, alpha, beta, distort,
                   distortion_sigma, legacy, output, text):
    """
    Generates artificial text line training data.

    Reads unique lines from the given text files, optionally normalizes,
    strips, filters and subsamples them, renders each line to an image and
    writes `NNNNNN.png` / `NNNNNN.gt.txt` pairs into `output`.

    Args:
        ctx: Click context (not used by the body).
        font, font_size, font_weight, language: Passed to
            `linegen.LineGenerator`.
        maxlines (int): If set and smaller than the number of unique lines,
                        that many distinct lines are sampled at random.
        encoding (str): Encoding used to open the text files.
        normalization (str): Unicode normalization form applied to all lines
                             after reading (skipped when falsy).
        renormalize (str): Unicode normalization form applied only to the
                           text handed to the renderer (skipped when falsy).
        reorder (bool): Write `get_display(line)` instead of the raw line to
                        the ground truth file.
        max_length (int): Keep only lines strictly shorter than this.
        strip (bool): Strip leading/trailing whitespace of each line.
        disable_degradation (bool): Skip `degrade_line`/`distort_line`.
        alpha, beta: Passed to `linegen.degrade_line`.
        distort, distortion_sigma: Means of the normal distributions whose
            absolute samples are passed to `linegen.distort_line`.
        legacy (bool): Use `linegen.ocropy_degrade` instead.
        output (str): Output directory, created if missing.
        text: Iterable of text file paths. Nothing is done if empty.
    """
    import numpy as np
    from kraken import linegen
    from kraken.lib.util import make_printable

    lines: Set[str] = set()
    if not text:
        return
    with log.progressbar(text, label='Reading texts') as bar:
        # iterate over the bar itself, otherwise the progress display never
        # advances
        for t in bar:
            with click.open_file(t, encoding=encoding) as fp:
                logger.info('Reading {}'.format(t))
                for l in fp:
                    lines.add(l.rstrip('\r\n'))
    if normalization:
        lines = {unicodedata.normalize(normalization, line) for line in lines}
    if strip:
        lines = {line.strip() for line in lines}
    if max_length:
        lines = {line for line in lines if len(line) < max_length}
    logger.info('Read {} lines'.format(len(lines)))
    message('Read {} unique lines'.format(len(lines)))
    if maxlines and maxlines < len(lines):
        message('Sampling {} lines\t'.format(maxlines), nl=False)
        llist = list(lines)
        # sample without replacement: drawing with replacement and collecting
        # into a set would silently yield fewer than `maxlines` lines
        picked = np.random.choice(len(llist), maxlines, replace=False)
        lines = {llist[idx] for idx in picked}
        message('\u2713', fg='green')

    os.makedirs(output, exist_ok=True)

    # calculate the alphabet and print it for verification purposes
    alphabet: Set[str] = set()
    for line in lines:
        alphabet.update(line)
    chars = []
    combining = []
    for char in sorted(alphabet):
        k = make_printable(char)
        # make_printable returns a different string for symbols it had to
        # escape; those are listed separately
        if k != char:
            combining.append(k)
        else:
            chars.append(k)
    message('Σ (len: {})'.format(len(alphabet)))
    message('Symbols: {}'.format(''.join(chars)))
    if combining:
        message('Combining Characters: {}'.format(', '.join(combining)))

    lg = linegen.LineGenerator(font, font_size, font_weight, language)
    with log.progressbar(lines, label='Writing images') as bar:
        for idx, line in enumerate(bar):
            logger.info(line)
            try:
                if renormalize:
                    im = lg.render_line(unicodedata.normalize(renormalize, line))
                else:
                    im = lg.render_line(line)
            except KrakenCairoSurfaceException as e:
                # line could not be rendered; skip it (its index stays unused)
                logger.info('{}: {} {}'.format(e.message, e.width, e.height))
                continue
            if not disable_degradation and not legacy:
                im = linegen.degrade_line(im, alpha=alpha, beta=beta)
                im = linegen.distort_line(im,
                                          abs(np.random.normal(distort)),
                                          abs(np.random.normal(distortion_sigma)))
            elif legacy:
                im = linegen.ocropy_degrade(im)
            im.save('{}/{:06d}.png'.format(output, idx))
            with open('{}/{:06d}.gt.txt'.format(output, idx), 'wb') as fp:
                if reorder:
                    fp.write(get_display(line).encode('utf-8'))
                else:
                    fp.write(line.encode('utf-8'))
def train(ctx, pad, output, spec, append, load, freq, quit, epochs, lag,
          min_delta, device, optimizer, lrate, momentum, weight_decay,
          schedule, partition, normalization, normalize_whitespace, codec,
          resize, reorder, training_files, evaluation_files, preload, threads,
          ground_truth):
    """
    Trains a model from image-text pairs.

    Args:
        ctx: Click context; used to exit with code 1 when the model codec
             does not cover the training data and `resize` is 'fail'.
        pad: Passed to `generate_input_transforms`.
        output (str): Filename prefix handed to the trainer. With early
                      stopping the best model is copied to
                      `{output}_best.mlmodel`.
        spec (str): Bracketed VGSL spec. When no model is loaded its first
                    block (`batch,height,width,channels`) seeds the input
                    transforms and an output layer is appended to it.
        append: Position after which `spec` is appended to a loaded model
                (requires `load`).
        load (str): Path of an existing model to continue from.
        freq: Event frequency of the trainer.
        quit (str): 'early' (EarlyStopping) or 'dumb' (EpochStopping).
        epochs (int): Number of epochs for 'dumb' stopping and the length of
                      the 1cycle schedule.
        lag, min_delta: Parameters of `EarlyStopping`.
        device (str): Device handed to the trainer.
        optimizer (str): Name of an optimizer class in `torch.optim`.
        lrate, momentum, weight_decay: Learning rate schedule parameters.
        schedule (str): '1cycle' or anything else for a constant rate.
        partition (float): Fraction of the ground truth used for training;
                           the remainder is used for validation. Forced to 1
                           when `evaluation_files` is given.
        normalization, normalize_whitespace, reorder: Passed to
            `GroundTruthDataset`.
        codec: Codec used to encode the training set. For loaded models it is
               only honoured when `resize` is 'both'.
        resize (str): 'fail', 'add' or 'both'; handling of a mismatch between
                      model codec and training alphabet (requires `load`
                      unless 'fail').
        training_files: Additional ground truth merged into `ground_truth`.
        evaluation_files: Explicit validation set.
        preload (bool or None): Preload data sets into memory. None enables
                                it implicitly for sets of at most 2500 lines.
        threads (int): Thread budget shared by data loading and computation.
        ground_truth: Ground truth line images.

    Raises:
        click.BadOptionUsage: on inconsistent or invalid `append`, `resize`,
                              `spec` or `quit` values.
        click.UsageError: if no training data was provided.
    """
    if not load and append:
        raise click.BadOptionUsage('append', 'append option requires loading an existing model')

    if resize != 'fail' and not load:
        raise click.BadOptionUsage('resize', 'resize option requires loading an existing model')

    import re
    import torch
    import shutil
    import numpy as np

    from torch.utils.data import DataLoader

    from kraken.lib import models, vgsl, train
    from kraken.lib.util import make_printable
    from kraken.lib.train import EarlyStopping, EpochStopping, TrainStopper, TrainScheduler, add_1cycle
    from kraken.lib.codec import PytorchCodec
    from kraken.lib.dataset import GroundTruthDataset, generate_input_transforms

    logger.info('Building ground truth set from {} line images'.format(len(ground_truth) + len(training_files)))

    completed_epochs = 0
    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise to output size is still unknown.
    nn = None

    if load:
        logger.info('Loading existing model from {} '.format(load))
        message('Loading existing model from {}'.format(load), nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        message('\u2713', fg='green', nl=False)

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        if spec[0] != '[' or spec[-1] != ']':
            raise click.BadOptionUsage('spec', 'VGSL spec {} not bracketed'.format(spec))
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage('spec', 'Invalid input spec {}'.format(blocks[0]))
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch, height, width, channels, pad)
    except KrakenInputException as e:
        raise click.BadOptionUsage('spec', str(e))

    # disable automatic partition when given evaluation set explicitly
    if evaluation_files:
        partition = 1
    ground_truth = list(ground_truth)

    # merge training_files into ground_truth list
    if training_files:
        ground_truth.extend(training_files)

    if len(ground_truth) == 0:
        raise click.UsageError('No training data was provided to the train command. Use `-t` or the `ground_truth` argument.')

    np.random.shuffle(ground_truth)

    if len(ground_truth) > 2500 and not preload:
        logger.info('Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter')
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    tr_im = ground_truth[:int(len(ground_truth) * partition)]
    if evaluation_files:
        logger.debug('Using {} lines from explicit eval set'.format(len(evaluation_files)))
        te_im = evaluation_files
    else:
        te_im = ground_truth[int(len(ground_truth) * partition):]
        logger.debug('Taking {} lines from training for evaluation'.format(len(te_im)))

    # set multiprocessing tensor sharing strategy
    if 'file_system' in torch.multiprocessing.get_all_sharing_strategies():
        logger.debug('Setting multiprocessing tensor sharing strategy to file_system')
        torch.multiprocessing.set_sharing_strategy('file_system')

    gt_set = GroundTruthDataset(normalization=normalization,
                                whitespace_normalization=normalize_whitespace,
                                reorder=reorder,
                                im_transforms=transforms,
                                preload=preload)
    with log.progressbar(tr_im, label='Building training set') as bar:
        for im in bar:
            logger.debug('Adding line {} to training set'.format(im))
            try:
                gt_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    val_set = GroundTruthDataset(normalization=normalization,
                                 whitespace_normalization=normalize_whitespace,
                                 reorder=reorder,
                                 im_transforms=transforms,
                                 preload=preload)
    with log.progressbar(te_im, label='Building validation set') as bar:
        for im in bar:
            logger.debug('Adding line {} to validation set'.format(im))
            try:
                val_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    logger.info('Training set {} lines, validation set {} lines, alphabet {} symbols'.format(len(gt_set._images), len(val_set._images), len(gt_set.alphabet)))
    alpha_diff_only_train = set(gt_set.alphabet).difference(set(val_set.alphabet))
    alpha_diff_only_val = set(val_set.alphabet).difference(set(gt_set.alphabet))
    if alpha_diff_only_train:
        logger.warning('alphabet mismatch: chars in training set only: {} (not included in accuracy test during training)'.format(alpha_diff_only_train))
    if alpha_diff_only_val:
        logger.warning('alphabet mismatch: chars in validation set only: {} (not trained)'.format(alpha_diff_only_val))
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(), key=lambda x: x[1], reverse=True):
        char = make_printable(k)
        # unescaped symbols get an extra tab in front
        if char == k:
            char = '\t' + char
        logger.info(u'{}\t{}'.format(char, v))

    logger.debug('Encoding training set')

    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label()+1)
        logger.info('Appending {} to existing model {} after {}'.format(spec, nn.spec, append))
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info('Assembled model spec: {}'.format(nn.spec))
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException:
            message('Network codec not compatible with training set')
            alpha_diff = set(gt_set.alphabet).difference(set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error('Training data and model codec alphabets mismatch: {}'.format(alpha_diff))
                ctx.exit(code=1)
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info('Resizing codec to include {} new code points'.format(len(alpha_diff)))
                codec.c2l.update({k: [v] for v, k in enumerate(alpha_diff, start=codec.max_label()+1)})
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info('Resizing last layer in network to {} outputs'.format(codec.max_label()+1))
                nn.resize_output(codec.max_label()+1)
                gt_set.encode(nn.codec)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ', nl=False)
                logger.info('Resizing network or given codec to {} code sequences'.format(len(gt_set.alphabet)))
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info('Deleting {} output classes from network ({} retained)'.format(len(del_labels), len(codec)-len(del_labels)))
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label()+1, del_labels)
                message('\u2713', fg='green')
            else:
                raise click.BadOptionUsage('resize', 'Invalid resize value {}'.format(resize))
    else:
        gt_set.encode(codec)
        logger.info('Creating new model {} with {} outputs'.format(spec, gt_set.codec.max_label()+1))
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label()+1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    # half the number of data loading processes if device isn't cuda and we haven't enabled preloading
    if device == 'cpu' and not preload:
        loader_threads = threads // 2
    else:
        loader_threads = threads
    train_loader = DataLoader(gt_set, batch_size=1, shuffle=True, num_workers=loader_threads, pin_memory=True)
    # Keep at least one compute thread: when the loader takes the whole
    # budget the remainder would be 0, which is not a usable thread count
    # (NOTE(review): nn.set_num_threads presumably forwards to
    # torch.set_num_threads, which requires a positive value - confirm).
    threads = max(threads - loader_threads, 1)

    # don't encode validation set as the alphabets may not match causing encoding failures
    val_set.training_set = list(zip(val_set._images, val_set._gt))

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(optimizer, lrate, momentum))

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    nn.set_num_threads(threads)

    logger.debug('Moving model to device {}'.format(device))
    # learning rate starts at 0; the scheduler phases added below set it
    optim = getattr(torch.optim, optimizer)(nn.nn.parameters(), lr=0)

    if 'accuracy' not in nn.user_metadata:
        nn.user_metadata['accuracy'] = []

    tr_it = TrainScheduler(optim)
    if schedule == '1cycle':
        add_1cycle(tr_it, int(len(gt_set) * epochs), lrate, momentum, momentum - 0.10, weight_decay)
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1, (lrate, lrate), (momentum, momentum), weight_decay, train.annealing_const)

    if quit == 'early':
        st_it = EarlyStopping(min_delta, lag)
    elif quit == 'dumb':
        st_it = EpochStopping(epochs - completed_epochs)
    else:
        raise click.BadOptionUsage('quit', 'Invalid training interruption scheme {}'.format(quit))

    trainer = train.KrakenTrainer(model=nn,
                                  optimizer=optim,
                                  device=device,
                                  filename_prefix=output,
                                  event_frequency=freq,
                                  train_set=train_loader,
                                  val_set=val_set,
                                  stopper=st_it)

    trainer.add_lr_scheduler(tr_it)

    with log.progressbar(label='stage {}/{}'.format(1, trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞'),
                         length=trainer.event_it, show_pos=True) as bar:

        def _draw_progressbar():
            bar.update(1)

        def _print_eval(epoch, accuracy, chars, error):
            message('Accuracy report ({}) {:0.4f} {} {}'.format(epoch, accuracy, chars, error))
            # reset progress bar
            bar.label = 'stage {}/{}'.format(epoch+1, trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞')
            bar.pos = 0
            bar.finished = False

        trainer.run(_print_eval, _draw_progressbar)

    if quit == 'early':
        message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format(output, trainer.stopper.best_epoch, trainer.stopper.best_loss))
        logger.info('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format(output, trainer.stopper.best_epoch, trainer.stopper.best_loss))
        shutil.copy('{}_{}.mlmodel'.format(output, trainer.stopper.best_epoch), '{}_best.mlmodel'.format(output))
def train(ctx, pad, output, spec, append, load, freq, quit, epochs, lag,
          min_delta, device, optimizer, lrate, momentum, weight_decay,
          schedule, partition, normalization, normalize_whitespace, codec,
          resize, reorder, training_files, evaluation_files, preload, threads,
          ground_truth):
    """
    Trains a model from image-text pairs.

    Args:
        ctx: Click context; used to exit with code 1 when the model codec
             does not cover the training data and `resize` is 'fail'.
        pad: Passed to `generate_input_transforms`.
        output (str): Filename prefix handed to the trainer. With early
                      stopping the best model is copied to
                      `{output}_best.mlmodel`.
        spec (str): Bracketed VGSL spec. When no model is loaded its first
                    block (`batch,height,width,channels`) seeds the input
                    transforms and an output layer is appended to it.
        append: Position after which `spec` is appended to a loaded model
                (requires `load`).
        load (str): Path of an existing model to continue from.
        freq: Event frequency of the trainer.
        quit (str): 'early' (EarlyStopping) or 'dumb' (EpochStopping).
        epochs (int): Number of epochs for 'dumb' stopping and the length of
                      the 1cycle schedule.
        lag, min_delta: Parameters of `EarlyStopping`.
        device (str): Device handed to the trainer.
        optimizer (str): Name of an optimizer class in `torch.optim`.
        lrate, momentum, weight_decay: Learning rate schedule parameters.
        schedule (str): '1cycle' or anything else for a constant rate.
        partition (float): Fraction of the ground truth used for training;
                           the remainder is used for validation. Forced to 1
                           when `evaluation_files` is given.
        normalization, normalize_whitespace, reorder: Passed to
            `GroundTruthDataset`.
        codec: Codec used to encode the training set. For loaded models it is
               only honoured when `resize` is 'both'.
        resize (str): 'fail', 'add' or 'both'; handling of a mismatch between
                      model codec and training alphabet (requires `load`
                      unless 'fail').
        training_files: Additional ground truth merged into `ground_truth`.
        evaluation_files: Explicit validation set.
        preload (bool or None): Preload data sets into memory. None enables
                                it implicitly for sets of at most 2500 lines.
        threads (int): Thread budget shared by data loading and computation.
        ground_truth: Ground truth line images.

    Raises:
        click.BadOptionUsage: on inconsistent or invalid `append`, `resize`,
                              `spec` or `quit` values.
        click.UsageError: if no training data was provided.
    """
    if not load and append:
        raise click.BadOptionUsage(
            'append', 'append option requires loading an existing model')

    if resize != 'fail' and not load:
        raise click.BadOptionUsage(
            'resize', 'resize option requires loading an existing model')

    import re
    import torch
    import shutil
    import numpy as np

    from torch.utils.data import DataLoader

    from kraken.lib import models, vgsl, train
    from kraken.lib.util import make_printable
    from kraken.lib.train import EarlyStopping, EpochStopping, TrainStopper, TrainScheduler, add_1cycle
    from kraken.lib.codec import PytorchCodec
    from kraken.lib.dataset import GroundTruthDataset, generate_input_transforms

    logger.info('Building ground truth set from {} line images'.format(
        len(ground_truth) + len(training_files)))

    completed_epochs = 0
    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise to output size is still unknown.
    nn = None

    if load:
        logger.info('Loading existing model from {} '.format(load))
        message('Loading existing model from {}'.format(load), nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        message('\u2713', fg='green', nl=False)

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        if spec[0] != '[' or spec[-1] != ']':
            raise click.BadOptionUsage(
                'spec', 'VGSL spec {} not bracketed'.format(spec))
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage(
                'spec', 'Invalid input spec {}'.format(blocks[0]))
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch, height, width, channels,
                                               pad)
    except KrakenInputException as e:
        raise click.BadOptionUsage('spec', str(e))

    # disable automatic partition when given evaluation set explicitly
    if evaluation_files:
        partition = 1
    ground_truth = list(ground_truth)

    # merge training_files into ground_truth list
    if training_files:
        ground_truth.extend(training_files)

    if len(ground_truth) == 0:
        raise click.UsageError(
            'No training data was provided to the train command. Use `-t` or the `ground_truth` argument.'
        )

    np.random.shuffle(ground_truth)

    if len(ground_truth) > 2500 and not preload:
        logger.info(
            'Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter'
        )
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    tr_im = ground_truth[:int(len(ground_truth) * partition)]
    if evaluation_files:
        logger.debug('Using {} lines from explicit eval set'.format(
            len(evaluation_files)))
        te_im = evaluation_files
    else:
        te_im = ground_truth[int(len(ground_truth) * partition):]
        logger.debug('Taking {} lines from training for evaluation'.format(
            len(te_im)))

    # set multiprocessing tensor sharing strategy
    if 'file_system' in torch.multiprocessing.get_all_sharing_strategies():
        logger.debug(
            'Setting multiprocessing tensor sharing strategy to file_system')
        torch.multiprocessing.set_sharing_strategy('file_system')

    gt_set = GroundTruthDataset(normalization=normalization,
                                whitespace_normalization=normalize_whitespace,
                                reorder=reorder,
                                im_transforms=transforms,
                                preload=preload)
    with log.progressbar(tr_im, label='Building training set') as bar:
        for im in bar:
            logger.debug('Adding line {} to training set'.format(im))
            try:
                gt_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    val_set = GroundTruthDataset(normalization=normalization,
                                 whitespace_normalization=normalize_whitespace,
                                 reorder=reorder,
                                 im_transforms=transforms,
                                 preload=preload)
    with log.progressbar(te_im, label='Building validation set') as bar:
        for im in bar:
            logger.debug('Adding line {} to validation set'.format(im))
            try:
                val_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    logger.info(
        'Training set {} lines, validation set {} lines, alphabet {} symbols'.
        format(len(gt_set._images), len(val_set._images),
               len(gt_set.alphabet)))
    alpha_diff = set(gt_set.alphabet).symmetric_difference(
        set(val_set.alphabet))
    if alpha_diff:
        # Logger.warn is a deprecated alias; use Logger.warning
        logger.warning('alphabet mismatch {}'.format(alpha_diff))
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(),
                       key=lambda x: x[1],
                       reverse=True):
        char = make_printable(k)
        # unescaped symbols get an extra tab in front
        if char == k:
            char = '\t' + char
        logger.info(u'{}\t{}'.format(char, v))

    logger.debug('Encoding training set')

    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        logger.info('Appending {} to existing model {} after {}'.format(
            spec, nn.spec, append))
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info('Assembled model spec: {}'.format(nn.spec))
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException:
            message('Network codec not compatible with training set')
            alpha_diff = set(gt_set.alphabet).difference(
                set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error(
                    'Training data and model codec alphabets mismatch: {}'.
                    format(alpha_diff))
                ctx.exit(code=1)
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info(
                    'Resizing codec to include {} new code points'.format(
                        len(alpha_diff)))
                codec.c2l.update({
                    k: [v]
                    for v, k in enumerate(alpha_diff,
                                          start=codec.max_label() + 1)
                })
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info(
                    'Resizing last layer in network to {} outputs'.format(
                        codec.max_label() + 1))
                nn.resize_output(codec.max_label() + 1)
                gt_set.encode(nn.codec)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ', nl=False)
                logger.info(
                    'Resizing network or given codec to {} code sequences'.
                    format(len(gt_set.alphabet)))
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info(
                    'Deleting {} output classes from network ({} retained)'.
                    format(len(del_labels),
                           len(codec) - len(del_labels)))
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label() + 1, del_labels)
                message('\u2713', fg='green')
            else:
                raise click.BadOptionUsage(
                    'resize', 'Invalid resize value {}'.format(resize))
    else:
        gt_set.encode(codec)
        logger.info('Creating new model {} with {} outputs'.format(
            spec,
            gt_set.codec.max_label() + 1))
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    # half the number of data loading processes if device isn't cuda and we haven't enabled preloading
    if device == 'cpu' and not preload:
        loader_threads = threads // 2
    else:
        loader_threads = threads
    train_loader = DataLoader(gt_set,
                              batch_size=1,
                              shuffle=True,
                              num_workers=loader_threads,
                              pin_memory=True)
    # Keep at least one compute thread: when the loader takes the whole
    # budget the remainder would be 0, which is not a usable thread count
    # (NOTE(review): nn.set_num_threads presumably forwards to
    # torch.set_num_threads, which requires a positive value - confirm).
    threads = max(threads - loader_threads, 1)

    # don't encode validation set as the alphabets may not match causing encoding failures
    val_set.training_set = list(zip(val_set._images, val_set._gt))

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(
        optimizer, lrate, momentum))

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    nn.set_num_threads(threads)

    logger.debug('Moving model to device {}'.format(device))
    # learning rate starts at 0; the scheduler phases added below set it
    optim = getattr(torch.optim, optimizer)(nn.nn.parameters(), lr=0)

    if 'accuracy' not in nn.user_metadata:
        nn.user_metadata['accuracy'] = []

    tr_it = TrainScheduler(optim)
    if schedule == '1cycle':
        add_1cycle(tr_it, int(len(gt_set) * epochs), lrate, momentum,
                   momentum - 0.10, weight_decay)
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1, (lrate, lrate), (momentum, momentum), weight_decay,
                        train.annealing_const)

    if quit == 'early':
        st_it = EarlyStopping(min_delta, lag)
    elif quit == 'dumb':
        st_it = EpochStopping(epochs - completed_epochs)
    else:
        raise click.BadOptionUsage(
            'quit', 'Invalid training interruption scheme {}'.format(quit))

    trainer = train.KrakenTrainer(model=nn,
                                  optimizer=optim,
                                  device=device,
                                  filename_prefix=output,
                                  event_frequency=freq,
                                  train_set=train_loader,
                                  val_set=val_set,
                                  stopper=st_it)

    trainer.add_lr_scheduler(tr_it)

    with log.progressbar(label='stage {}/{}'.format(
            1, trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞'),
                         length=trainer.event_it,
                         show_pos=True) as bar:

        def _draw_progressbar():
            bar.update(1)

        def _print_eval(epoch, accuracy, chars, error):
            message('Accuracy report ({}) {:0.4f} {} {}'.format(
                epoch, accuracy, chars, error))
            # reset progress bar
            bar.label = 'stage {}/{}'.format(
                epoch + 1,
                trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞')
            bar.pos = 0
            bar.finished = False

        trainer.run(_print_eval, _draw_progressbar)

    if quit == 'early':
        message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
                format(output, trainer.stopper.best_epoch,
                       trainer.stopper.best_loss))
        logger.info(
            'Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
            format(output, trainer.stopper.best_epoch,
                   trainer.stopper.best_loss))
        shutil.copy('{}_{}.mlmodel'.format(output, trainer.stopper.best_epoch),
                    '{}_best.mlmodel'.format(output))
def line_generator(ctx, font, maxlines, encoding, normalization, renormalize,
                   reorder, font_size, font_weight, language, max_length,
                   strip, disable_degradation, alpha, beta, distort,
                   distortion_sigma, legacy, output, text):
    """
    Generates artificial text line training data.

    Reads unique lines from the given text files, optionally normalizes,
    strips, filters and subsamples them, renders each line to an image and
    writes `NNNNNN.png` / `NNNNNN.gt.txt` pairs into `output`.

    Args:
        ctx: Click context (not used by the body).
        font, font_size, font_weight, language: Passed to
            `linegen.LineGenerator`.
        maxlines (int): If set and smaller than the number of unique lines,
                        that many distinct lines are sampled at random.
        encoding (str): Encoding used to open the text files.
        normalization (str): Unicode normalization form applied to all lines
                             after reading (skipped when falsy).
        renormalize (str): Unicode normalization form applied only to the
                           text handed to the renderer (skipped when falsy).
        reorder (bool): Write `get_display(line)` instead of the raw line to
                        the ground truth file.
        max_length (int): Keep only lines strictly shorter than this.
        strip (bool): Strip leading/trailing whitespace of each line.
        disable_degradation (bool): Skip `degrade_line`/`distort_line`.
        alpha, beta: Passed to `linegen.degrade_line`.
        distort, distortion_sigma: Means of the normal distributions whose
            absolute samples are passed to `linegen.distort_line`.
        legacy (bool): Use `linegen.ocropy_degrade` instead.
        output (str): Output directory, created if missing.
        text: Iterable of text file paths. Nothing is done if empty.
    """
    import numpy as np
    from kraken import linegen
    from kraken.lib.util import make_printable

    lines: Set[str] = set()
    if not text:
        return
    with log.progressbar(text, label='Reading texts') as bar:
        # iterate over the bar itself, otherwise the progress display never
        # advances
        for t in bar:
            with click.open_file(t, encoding=encoding) as fp:
                logger.info('Reading {}'.format(t))
                for l in fp:
                    lines.add(l.rstrip('\r\n'))
    if normalization:
        lines = {
            unicodedata.normalize(normalization, line)
            for line in lines
        }
    if strip:
        lines = {line.strip() for line in lines}
    if max_length:
        lines = {line for line in lines if len(line) < max_length}
    logger.info('Read {} lines'.format(len(lines)))
    message('Read {} unique lines'.format(len(lines)))
    if maxlines and maxlines < len(lines):
        message('Sampling {} lines\t'.format(maxlines), nl=False)
        llist = list(lines)
        # sample without replacement: drawing with replacement and collecting
        # into a set would silently yield fewer than `maxlines` lines
        picked = np.random.choice(len(llist), maxlines, replace=False)
        lines = {llist[idx] for idx in picked}
        message('\u2713', fg='green')

    os.makedirs(output, exist_ok=True)

    # calculate the alphabet and print it for verification purposes
    alphabet: Set[str] = set()
    for line in lines:
        alphabet.update(line)
    chars = []
    combining = []
    for char in sorted(alphabet):
        k = make_printable(char)
        # make_printable returns a different string for symbols it had to
        # escape; those are listed separately
        if k != char:
            combining.append(k)
        else:
            chars.append(k)
    message('Σ (len: {})'.format(len(alphabet)))
    message('Symbols: {}'.format(''.join(chars)))
    if combining:
        message('Combining Characters: {}'.format(', '.join(combining)))

    lg = linegen.LineGenerator(font, font_size, font_weight, language)
    with log.progressbar(lines, label='Writing images') as bar:
        for idx, line in enumerate(bar):
            logger.info(line)
            try:
                if renormalize:
                    im = lg.render_line(
                        unicodedata.normalize(renormalize, line))
                else:
                    im = lg.render_line(line)
            except KrakenCairoSurfaceException as e:
                # line could not be rendered; skip it (its index stays unused)
                logger.info('{}: {} {}'.format(e.message, e.width, e.height))
                continue
            if not disable_degradation and not legacy:
                im = linegen.degrade_line(im, alpha=alpha, beta=beta)
                im = linegen.distort_line(
                    im, abs(np.random.normal(distort)),
                    abs(np.random.normal(distortion_sigma)))
            elif legacy:
                im = linegen.ocropy_degrade(im)
            im.save('{}/{:06d}.png'.format(output, idx))
            with open('{}/{:06d}.gt.txt'.format(output, idx), 'wb') as fp:
                if reorder:
                    fp.write(get_display(line).encode('utf-8'))
                else:
                    fp.write(line.encode('utf-8'))
def recognition_train_gen(
        cls,
        hyper_params: Dict = default_specs.RECOGNITION_HYPER_PARAMS,
        progress_callback: Callable[[str, int], Callable[
            [None], None]] = lambda string, length: lambda: None,
        message: Callable[[str], None] = lambda *args, **kwargs: None,
        output: str = 'model',
        spec: str = default_specs.RECOGNITION_SPEC,
        append: Optional[int] = None,
        load: Optional[str] = None,
        device: str = 'cpu',
        reorder: bool = True,
        training_data: Sequence[Dict] = None,
        evaluation_data: Sequence[Dict] = None,
        preload: Optional[bool] = None,
        threads: int = 1,
        load_hyper_parameters: bool = False,
        repolygonize: bool = False,
        force_binarization: bool = False,
        format_type: str = 'path',
        codec: Optional[Dict] = None,
        resize: str = 'fail',
        augment: bool = False):
    """
    This is an ugly constructor that takes all the arguments from the command
    line driver, finagles the datasets, models, and hyperparameters correctly
    and returns a KrakenTrainer object.

    Setup parameters (load, training_data, evaluation_data, ....) are named,
    model hyperparameters (everything in
    kraken.lib.default_specs.RECOGNITION_HYPER_PARAMS) are in the
    `hyper_params` argument.

    Args:
        hyper_params (dict): Hyperparameter dictionary containing all fields
                             from
                             kraken.lib.default_specs.RECOGNITION_HYPER_PARAMS.
                             A caller-supplied dictionary is updated in place
                             when `load_hyper_parameters` is set; the module
                             level default is never modified.
        progress_callback (Callable): Callback for progress reports on
                                      various computationally expensive
                                      processes. A human readable string and
                                      the process length is supplied. The
                                      callback has to return another function
                                      which will be executed after each step.
        message (Callable): Messaging printing method for above log but below
                            warning level output, i.e. infos that should
                            generally be shown to users.
        **kwargs: Setup parameters, i.e. CLI parameters of the train()
                  command. `training_data`/`evaluation_data` left at `None`
                  are treated as empty. Note that augmentation is read from
                  `hyper_params['augment']`; the `augment` argument is not
                  consulted by this function.

    Returns:
        A KrakenTrainer object, or `None` when no valid training data is
        available, the model codec does not fit the data with
        `resize == 'fail'`, or an invalid `resize`/`quit` value is given.

    Raises:
        click.BadOptionUsage: On a malformed VGSL `spec` or when the input
                              transforms cannot be built.
    """
    # never mutate the shared module-level defaults through
    # `hyper_params.update()` below
    if hyper_params is default_specs.RECOGNITION_HYPER_PARAMS:
        hyper_params = hyper_params.copy()
    # the signature defaults are None; treat that as "no data"
    if training_data is None:
        training_data = []
    if evaluation_data is None:
        evaluation_data = []

    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise the output size is still unknown.
    nn = None

    if load:
        logger.info(f'Loading existing model from {load} ')
        message(f'Loading existing model from {load} ', nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        if load_hyper_parameters:
            hyper_params.update(nn.hyper_params)
            nn.hyper_params = hyper_params
        message('\u2713', fg='green', nl=False)

    DatasetClass = GroundTruthDataset
    valid_norm = True
    if format_type and format_type != 'path':
        logger.info(
            f'Parsing {len(training_data)} XML files for training data')
        if repolygonize:
            message('Repolygonizing data')
        training_data = preparse_xml_data(training_data, format_type,
                                          repolygonize)
        evaluation_data = preparse_xml_data(evaluation_data, format_type,
                                            repolygonize)
        DatasetClass = PolygonGTDataset
        valid_norm = False
    elif format_type == 'path':
        if force_binarization:
            logger.warning(
                'Forced binarization enabled in `path` mode. Will be ignored.'
            )
            force_binarization = False
        if repolygonize:
            logger.warning(
                'Repolygonization enabled in `path` mode. Will be ignored.'
            )
        training_data = [{'image': im} for im in training_data]
        if evaluation_data:
            evaluation_data = [{'image': im} for im in evaluation_data]
        valid_norm = True
    # format_type is None. Determine training type from length of training data entry
    else:
        if training_data and len(training_data[0]) >= 4:
            DatasetClass = PolygonGTDataset
            valid_norm = False
        else:
            if force_binarization:
                logger.warning(
                    'Forced binarization enabled with box lines. Will be ignored.'
                )
                force_binarization = False
            if repolygonize:
                logger.warning(
                    'Repolygonization enabled with box lines. Will be ignored.'
                )

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        # startswith/endswith also rejects an empty spec cleanly
        if not (spec.startswith('[') and spec.endswith(']')):
            raise click.BadOptionUsage(
                'spec', 'VGSL spec {} not bracketed'.format(spec))
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage('spec',
                                       f'Invalid input spec {blocks[0]}')
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch, height, width,
                                               channels, hyper_params['pad'],
                                               valid_norm,
                                               force_binarization)
    except KrakenInputException as e:
        raise click.BadOptionUsage('spec', str(e))

    if len(training_data) > 2500 and not preload:
        logger.info(
            'Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter'
        )
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    # set multiprocessing tensor sharing strategy
    if 'file_system' in torch.multiprocessing.get_all_sharing_strategies():
        logger.debug(
            'Setting multiprocessing tensor sharing strategy to file_system'
        )
        torch.multiprocessing.set_sharing_strategy('file_system')

    gt_set = DatasetClass(
        normalization=hyper_params['normalization'],
        whitespace_normalization=hyper_params['normalize_whitespace'],
        reorder=reorder,
        im_transforms=transforms,
        preload=preload,
        augmentation=hyper_params['augment'])
    bar = progress_callback('Building training set', len(training_data))
    for im in training_data:
        logger.debug(f'Adding line {im} to training set')
        try:
            gt_set.add(**im)
            bar()
        except FileNotFoundError as e:
            logger.warning(f'{e.strerror}: {e.filename}. Skipping.')
        except KrakenInputException as e:
            logger.warning(str(e))

    val_set = DatasetClass(
        normalization=hyper_params['normalization'],
        whitespace_normalization=hyper_params['normalize_whitespace'],
        reorder=reorder,
        im_transforms=transforms,
        preload=preload)
    bar = progress_callback('Building validation set', len(evaluation_data))
    for im in evaluation_data:
        logger.debug(f'Adding line {im} to validation set')
        try:
            val_set.add(**im)
            bar()
        except FileNotFoundError as e:
            logger.warning(f'{e.strerror}: {e.filename}. Skipping.')
        except KrakenInputException as e:
            logger.warning(str(e))

    if len(gt_set._images) == 0:
        logger.error(
            'No valid training data was provided to the train command. Please add valid XML or line data.'
        )
        return None

    logger.info(
        f'Training set {len(gt_set._images)} lines, validation set {len(val_set._images)} lines, alphabet {len(gt_set.alphabet)} symbols'
    )
    alpha_diff_only_train = set(gt_set.alphabet).difference(
        set(val_set.alphabet))
    alpha_diff_only_val = set(val_set.alphabet).difference(
        set(gt_set.alphabet))
    if alpha_diff_only_train:
        logger.warning(
            f'alphabet mismatch: chars in training set only: {alpha_diff_only_train} (not included in accuracy test during training)'
        )
    if alpha_diff_only_val:
        logger.warning(
            f'alphabet mismatch: chars in validation set only: {alpha_diff_only_val} (not trained)'
        )
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(),
                       key=lambda x: x[1],
                       reverse=True):
        char = make_printable(k)
        if char == k:
            char = '\t' + char
        logger.info(f'{char}\t{v}')

    logger.debug('Encoding training set')

    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        logger.info(
            f'Appending {spec} to existing model {nn.spec} after {append}')
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info(f'Assembled model spec: {nn.spec}')
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException:
            message('Network codec not compatible with training set')
            alpha_diff = set(gt_set.alphabet).difference(
                set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error(
                    f'Training data and model codec alphabets mismatch: {alpha_diff}'
                )
                return None
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info(
                    f'Resizing codec to include {len(alpha_diff)} new code points'
                )
                codec.c2l.update({
                    k: [v]
                    for v, k in enumerate(alpha_diff,
                                          start=codec.max_label() + 1)
                })
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info(
                    f'Resizing last layer in network to {codec.max_label()+1} outputs'
                )
                nn.resize_output(codec.max_label() + 1)
                # the first encode() attempt failed, so encode again with the
                # extended codec
                gt_set.encode(nn.codec)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ',
                        nl=False)
                logger.info(
                    f'Resizing network or given codec to {gt_set.alphabet} code sequences'
                )
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info(
                    f'Deleting {len(del_labels)} output classes from network ({len(codec)-len(del_labels)} retained)'
                )
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label() + 1, del_labels)
                message('\u2713', fg='green')
            else:
                logger.error(f'invalid resize parameter value {resize}')
                return None
    else:
        gt_set.encode(codec)
        logger.info(
            f'Creating new model {spec} with {gt_set.codec.max_label()+1} outputs'
        )
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    if nn.one_channel_mode and gt_set.im_mode != nn.one_channel_mode:
        logger.warning(
            f'Neural network has been trained on mode {nn.one_channel_mode} images, training set contains mode {gt_set.im_mode} data. Consider setting `force_binarization`'
        )

    if format_type != 'path' and nn.seg_type == 'bbox':
        logger.warning(
            'Neural network has been trained on bounding box image information but training set is polygonal.'
        )

    # half the number of data loading processes if device isn't cuda and we haven't enabled preloading
    if device == 'cpu' and not preload:
        loader_threads = threads // 2
    else:
        loader_threads = threads
    train_loader = InfiniteDataLoader(
        gt_set,
        batch_size=hyper_params['batch_size'],
        shuffle=True,
        num_workers=loader_threads,
        pin_memory=True,
        collate_fn=collate_sequences)
    # leave the remaining threads (at least one) to the model itself
    threads = max(threads - loader_threads, 1)

    # don't encode validation set as the alphabets may not match causing encoding failures
    val_set.no_encode()
    val_loader = DataLoader(val_set,
                            batch_size=hyper_params['batch_size'],
                            num_workers=loader_threads,
                            pin_memory=True,
                            collate_fn=collate_sequences)

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(
        hyper_params['optimizer'], hyper_params['lrate'],
        hyper_params['momentum']))

    # set model type metadata field
    nn.model_type = 'recognition'

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug(f'Set OpenMP threads to {threads}')
    nn.set_num_threads(threads)

    # the learning rate is set to 0 here; the phases registered on the
    # TrainScheduler below carry the configured rate
    optim = getattr(torch.optim,
                    hyper_params['optimizer'])(nn.nn.parameters(), lr=0)

    if 'seg_type' not in nn.user_metadata:
        nn.user_metadata[
            'seg_type'] = 'baselines' if format_type != 'path' else 'bbox'

    tr_it = TrainScheduler(optim)
    if hyper_params['schedule'] == '1cycle':
        add_1cycle(tr_it, int(len(gt_set) * hyper_params['epochs']),
                   hyper_params['lrate'], hyper_params['momentum'],
                   hyper_params['momentum'] - 0.10,
                   hyper_params['weight_decay'])
    elif hyper_params['schedule'] == 'exponential':
        add_exponential_decay(tr_it,
                              int(len(gt_set) * hyper_params['epochs']),
                              len(gt_set), hyper_params['lrate'], 0.95,
                              hyper_params['momentum'],
                              hyper_params['weight_decay'])
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1, 2 * (hyper_params['lrate'], ),
                        2 * (hyper_params['momentum'], ),
                        hyper_params['weight_decay'], annealing_const)

    if hyper_params['quit'] == 'early':
        st_it = EarlyStopping(hyper_params['min_delta'],
                              hyper_params['lag'])
    elif hyper_params['quit'] == 'dumb':
        st_it = EpochStopping(hyper_params['epochs'] -
                              hyper_params['completed_epochs'])
    else:
        # report the configured value, not the `quit` builtin
        logger.error(
            f'Invalid training interruption scheme {hyper_params["quit"]}')
        return None

    trainer = cls(model=nn,
                  optimizer=optim,
                  device=device,
                  filename_prefix=output,
                  event_frequency=hyper_params['freq'],
                  train_set=train_loader,
                  val_set=val_loader,
                  stopper=st_it)

    trainer.add_lr_scheduler(tr_it)

    return trainer
def build_binary_dataset(
        files: Optional[List[Union[str, pathlib.Path]]] = None,
        output_file: Union[str, pathlib.Path] = None,
        format_type: str = 'xml',
        num_workers: int = 0,
        ignore_splits: bool = False,
        random_split: Optional[Tuple[float, float, float]] = None,
        force_type: Optional[str] = None,
        recordbatch_size: int = 100,
        callback: Callable[[int, int], None] = lambda chunk, lines: None) -> None:
    """
    Parses XML files and dumps the baseline-style line images and text into a
    binary dataset.

    The lines are first streamed into an Arrow IPC file in a temporary
    directory; once the aggregate metadata (counts, image mode) is known the
    table is rewritten to `output_file` with that metadata attached to the
    schema.

    Args:
        files: List of XML input files. `None` is treated as an empty list.
        output_file: Path to the output file.
        format_type: One of `xml`, `alto`, `page`, or `path`.
        num_workers: Number of workers for parallelized extraction of line
                     images. Values of `0` or `1` extract serially in the
                     calling process.
        ignore_splits: Switch to disable serialization of the explicit
                       train/validation/test splits contained in the source
                       files. In this function it only controls a warning in
                       `path` mode.
        random_split: Serializes a random split into the dataset with the
                      proportions (train, val, test).
        force_type: Forces a dataset type. Can be
                    `kraken_recognition_baseline` or
                    `kraken_recognition_bbox`.
        recordbatch_size: Minimum number of records per RecordBatch written to
                          the output file. Larger batches require more
                          transient memory but slightly improve reading
                          performance.
        callback: Function called every time a new recordbatch is flushed into
                  the Arrow IPC file. It receives the number of lines in the
                  flushed batch and the total number of lines; it is also
                  called once with `(0, total)` before extraction starts.

    Raises:
        ValueError: On an unknown `format_type` or `force_type`.
    """
    logger.info('Parsing XML files')
    extract_fn = _extract_line
    if format_type == 'xml':
        parse_fn = parse_xml
    elif format_type == 'alto':
        parse_fn = parse_alto
    elif format_type == 'page':
        parse_fn = parse_page
    elif format_type == 'path':
        if not ignore_splits:
            logger.warning(
                'ignore_splits is False and format_type is path. Will not serialize splits.'
            )
        parse_fn = parse_path
        extract_fn = _extract_path_line
    else:
        raise ValueError(
            f'invalid format {format_type} for parse_(xml,alto,page,path)')

    if force_type and force_type not in [
            'kraken_recognition_baseline', 'kraken_recognition_bbox'
    ]:
        raise ValueError(f'force_type set to invalid value {force_type}')

    # the signature default is None; treat that as "no input files"
    if files is None:
        files = []
    # hand a plain string to the Arrow file API (no-op for str input)
    if isinstance(output_file, pathlib.PurePath):
        output_file = str(output_file)

    docs = []
    for doc in files:
        try:
            data = parse_fn(doc)
        except KrakenInputException:
            logger.warning(f'Invalid input file {doc}')
            continue
        # only keep documents whose image file can actually be opened
        try:
            with open(data['image'], 'rb') as fp:
                Image.open(fp)
        except FileNotFoundError as e:
            logger.warning(f'Could not open file {e.filename} in {doc}')
            continue
        docs.append(data)
    logger.info(f'Parsed {len(docs)} files.')

    logger.info('Assembling dataset alphabet.')
    alphabet = Counter()
    num_lines = 0
    for doc in docs:
        for line in doc['lines']:
            num_lines += 1
            alphabet.update(line['text'])

    callback(0, num_lines)

    for k, v in sorted(alphabet.items(), key=lambda x: x[1], reverse=True):
        char = make_printable(k)
        if char == k:
            char = '\t' + char
        logger.info(f'{char}\t{v}')

    if force_type:
        ds_type = force_type
    else:
        ds_type = 'kraken_recognition_baseline' if format_type != 'path' else 'kraken_recognition_bbox'

    metadata = {
        'lines': {
            'type': ds_type,
            'alphabet': alphabet,
            'text_type': 'raw',
            'image_type': 'raw',
            'splits': ['train', 'eval', 'test'],
            'im_mode': '1',
            'counts': Counter({
                'all': 0,
                'train': 0,
                'validation': 0,
                'test': 0
            }),
        }
    }

    ty = pa.struct([('text', pa.string()), ('im', pa.binary())])
    schema = pa.schema([('lines', ty), ('train', pa.bool_()),
                        ('validation', pa.bool_()), ('test', pa.bool_())])

    def _make_record_batch(line_cache):
        """Builds a RecordBatch plus its (all, train, val, test) counts."""
        ar = pa.array(line_cache, type=ty)
        if random_split:
            # class 0 ("unassigned") gets probability 0; tuple() also accepts
            # a split given as a list
            indices = np.random.choice(4,
                                       len(line_cache),
                                       p=(0.0, ) + tuple(random_split))
        else:
            indices = np.zeros(len(line_cache))
        tr_ind = indices == 1
        val_ind = indices == 2
        test_ind = indices == 3
        rbatch = pa.RecordBatch.from_arrays(
            [ar, pa.array(tr_ind), pa.array(val_ind), pa.array(test_ind)],
            schema=schema)
        return rbatch, (len(line_cache), int(np.count_nonzero(tr_ind)),
                        int(np.count_nonzero(val_ind)),
                        int(np.count_nonzero(test_ind)))

    def _flush(writer, line_cache):
        """Writes `line_cache` as one batch and accumulates the counts."""
        rbatch, counts = _make_record_batch(line_cache)
        # Counter.update() adds to the running totals
        metadata['lines']['counts'].update({
            'all': counts[0],
            'train': counts[1],
            'validation': counts[2],
            'test': counts[3]
        })
        writer.write(rbatch)
        callback(len(line_cache), num_lines)

    def _drain(results, writer, tmp_file):
        """Consumes extraction results, flushing full batches; returns the rest."""
        line_cache = []
        for page_lines, im_mode in results:
            if page_lines:
                line_cache.extend(page_lines)
                # comparison RGB(A) > L > 1
                if im_mode > metadata['lines']['im_mode']:
                    metadata['lines']['im_mode'] = im_mode

            if len(line_cache) >= recordbatch_size:
                logger.info(
                    f'Flushing {len(line_cache)} lines into {tmp_file}.')
                _flush(writer, line_cache)
                line_cache = []
        return line_cache

    logger.info('Writing lines to temporary file.')
    with tempfile.TemporaryDirectory() as tmp_output_dir:
        tmp_file = tmp_output_dir + '/dataset.arrow'
        with pa.OSFile(tmp_file, 'wb') as sink:
            with pa.ipc.new_file(sink, schema) as writer:

                if num_workers and num_workers > 1:
                    logger.info(
                        f'Spinning up processing pool with {num_workers} workers.'
                    )
                    with Pool(num_workers) as pool:
                        line_cache = _drain(
                            pool.imap_unordered(extract_fn, docs), writer,
                            tmp_file)
                else:
                    line_cache = _drain(map(extract_fn, docs), writer,
                                        tmp_file)

                if line_cache:
                    logger.info(
                        f'Flushing last {len(line_cache)} lines into {tmp_file}.'
                    )
                    _flush(writer, line_cache)

        logger.info('Dataset metadata')
        logger.info(f"type: {metadata['lines']['type']}\n"
                    f"text_type: {metadata['lines']['text_type']}\n"
                    f"image_type: {metadata['lines']['image_type']}\n"
                    f"splits: {metadata['lines']['splits']}\n"
                    f"im_mode: {metadata['lines']['im_mode']}\n"
                    f"lines: {metadata['lines']['counts']}\n")

        # the final metadata is only known now, so the table is read back and
        # written again under a schema carrying it
        with pa.memory_map(tmp_file, 'rb') as source:
            logger.info(
                f'Rewriting output ({output_file}) to update metadata.')
            ds = pa.ipc.open_file(source).read_all()
            metadata['lines']['counts'] = dict(metadata['lines']['counts'])
            metadata['lines'] = json.dumps(metadata['lines'])
            schema = schema.with_metadata(metadata)
            with pa.OSFile(output_file, 'wb') as sink:
                with pa.ipc.new_file(sink, schema) as writer:
                    writer.write(ds)
def train(ctx, pad, output, spec, append, load, savefreq, report, quit, epochs,
          lag, min_delta, device, optimizer, lrate, momentum, weight_decay,
          schedule, partition, normalization, codec, resize, reorder,
          training_files, evaluation_files, preload, threads, ground_truth):
    """
    Trains a model from image-text pairs.

    Builds training and validation sets from the given line images, creates
    or loads (and optionally resizes/extends) a model, then runs the training
    loop, saving a model every `savefreq` epochs and evaluating on the
    validation set every `report` epochs.

    Args:
        ctx: Command context; used to exit with status 1 on a codec mismatch
             with `resize == 'fail'`.
        pad: Padding passed to `generate_input_transforms`.
        output (str): Prefix of the written `{output}_{epoch}.mlmodel` files.
        spec (str): VGSL spec of a new model or of the layers to append.
        append: Layer index after which `spec` is appended to a loaded model.
        load (str): Path of an existing model to load.
        savefreq (int): Model is saved when `epoch % savefreq == 0`.
        report (int): Evaluation runs when `epoch % report == 0`.
        quit (str): Stopping scheme, `early` or `dumb`.
        epochs (int): Number of epochs for the `dumb` stopper and the 1cycle
                      schedule.
        lag: `lag` argument of `EarlyStopping`.
        min_delta: `min_delta` argument of `EarlyStopping`.
        device (str): Device the tensors and recognizer are moved to.
        optimizer (str): Name of a class in `torch.optim`.
        lrate: Learning rate.
        momentum: Momentum.
        weight_decay: Weight decay.
        schedule (str): `1cycle` or anything else for a constant schedule.
        partition (float): Fraction of `ground_truth` used for training when
                           no explicit evaluation files are given.
        normalization: Normalization passed to the datasets.
        codec: Explicit codec, used for new/appended models and preferred over
               the network codec when `resize == 'both'`.
        resize (str): `fail`, `add` or `both`; how to reconcile a loaded
                      model's codec with the training alphabet.
        reorder: Reordering flag passed to the datasets.
        training_files: Additional training line images.
        evaluation_files: Explicit validation line images.
        preload: Preload data sets into memory; `None` enables it only for
                 sets of at most 2500 lines.
        threads (int): Number of OpenMP threads.
        ground_truth: Training line images.

    Raises:
        click.BadOptionUsage: On inconsistent or invalid option values.
        click.UsageError: When no training data is given.
    """
    if not load and append:
        raise click.BadOptionUsage(
            'append', 'append option requires loading an existing model')

    if resize != 'fail' and not load:
        raise click.BadOptionUsage(
            'resize', 'resize option requires loading an existing model')

    import re
    import torch
    import shutil
    import numpy as np

    from torch.utils.data import DataLoader

    from kraken.lib import models, vgsl, train
    from kraken.lib.util import make_printable
    from kraken.lib.train import EarlyStopping, EpochStopping, TrainStopper, TrainScheduler, add_1cycle
    from kraken.lib.codec import PytorchCodec
    from kraken.lib.dataset import GroundTruthDataset, compute_error, generate_input_transforms

    logger.info('Building ground truth set from {} line images'.format(
        len(ground_truth) + len(training_files)))

    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise the output size is still unknown.
    nn = None

    if load:
        logger.info('Loading existing model from {} '.format(load))
        message('Loading model {}'.format(load), nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        message('\u2713', fg='green', nl=False)

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        # startswith/endswith also rejects an empty spec cleanly
        if not (spec.startswith('[') and spec.endswith(']')):
            raise click.BadOptionUsage(
                'spec', 'VGSL spec {} not bracketed'.format(spec))
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage(
                'spec', 'Invalid input spec {}'.format(blocks[0]))
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch, height, width,
                                               channels, pad)
    except KrakenInputException as e:
        raise click.BadOptionUsage('spec', str(e))

    # disable automatic partition when given evaluation set explicitly
    if evaluation_files:
        partition = 1

    ground_truth = list(ground_truth)

    # merge training_files into ground_truth list
    if training_files:
        ground_truth.extend(training_files)

    if len(ground_truth) == 0:
        raise click.UsageError(
            'No training data was provided to the train command. Use `-t` or the `ground_truth` argument.'
        )

    np.random.shuffle(ground_truth)

    if len(ground_truth) > 2500 and not preload:
        logger.info(
            'Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter'
        )
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    tr_im = ground_truth[:int(len(ground_truth) * partition)]
    if evaluation_files:
        logger.debug('Using {} lines from explicit eval set'.format(
            len(evaluation_files)))
        te_im = evaluation_files
    else:
        te_im = ground_truth[int(len(ground_truth) * partition):]
        logger.debug('Taking {} lines from training for evaluation'.format(
            len(te_im)))

    gt_set = GroundTruthDataset(normalization=normalization,
                                reorder=reorder,
                                im_transforms=transforms,
                                preload=preload)
    with log.progressbar(tr_im, label='Building training set') as bar:
        for im in bar:
            logger.debug('Adding line {} to training set'.format(im))
            try:
                gt_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    val_set = GroundTruthDataset(normalization=normalization,
                                 reorder=reorder,
                                 im_transforms=transforms,
                                 preload=preload)
    with log.progressbar(te_im, label='Building validation set') as bar:
        for im in bar:
            logger.debug('Adding line {} to validation set'.format(im))
            try:
                val_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    logger.info(
        'Training set {} lines, validation set {} lines, alphabet {} symbols'.
        format(len(gt_set._images), len(val_set._images),
               len(gt_set.alphabet)))
    alpha_diff = set(gt_set.alphabet).symmetric_difference(
        set(val_set.alphabet))
    if alpha_diff:
        # Logger.warn() is a deprecated alias of Logger.warning()
        logger.warning('alphabet mismatch {}'.format(alpha_diff))
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(),
                       key=lambda x: x[1],
                       reverse=True):
        char = make_printable(k)
        if char == k:
            char = '\t' + char
        logger.info(u'{}\t{}'.format(char, v))

    logger.debug('Encoding training set')

    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        logger.info('Appending {} to existing model {} after {}'.format(
            spec, nn.spec, append))
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info('Assembled model spec: {}'.format(nn.spec))
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException:
            message('Network codec not compatible with training set')
            alpha_diff = set(gt_set.alphabet).difference(
                set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error(
                    'Training data and model codec alphabets mismatch: {}'.
                    format(alpha_diff))
                ctx.exit(code=1)
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info(
                    'Resizing codec to include {} new code points'.format(
                        len(alpha_diff)))
                codec.c2l.update({
                    k: [v]
                    for v, k in enumerate(alpha_diff,
                                          start=codec.max_label() + 1)
                })
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info(
                    'Resizing last layer in network to {} outputs'.format(
                        codec.max_label() + 1))
                nn.resize_output(codec.max_label() + 1)
                # the first encode() attempt raised, so the training set has
                # to be encoded again with the extended codec
                gt_set.encode(nn.codec)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ',
                        nl=False)
                logger.info(
                    'Resizing network or given codec to {} code sequences'.
                    format(len(gt_set.alphabet)))
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info(
                    'Deleting {} output classes from network ({} retained)'.
                    format(len(del_labels),
                           len(codec) - len(del_labels)))
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label() + 1, del_labels)
                message('\u2713', fg='green')
            else:
                raise click.BadOptionUsage(
                    'resize', 'Invalid resize value {}'.format(resize))
    else:
        gt_set.encode(codec)
        logger.info('Creating new model {} with {} outputs'.format(
            spec,
            gt_set.codec.max_label() + 1))
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    train_loader = DataLoader(gt_set,
                              batch_size=1,
                              shuffle=True,
                              pin_memory=True)

    # don't encode validation set as the alphabets may not match causing encoding failures
    val_set.training_set = list(zip(val_set._images, val_set._gt))

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(
        optimizer, lrate, momentum))

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    nn.set_num_threads(threads)

    logger.debug('Moving model to device {}'.format(device))
    rec = models.TorchSeqRecognizer(nn, train=True, device=device)
    # the learning rate is set to 0 here; the phases registered on the
    # TrainScheduler below carry the configured rate
    optim = getattr(torch.optim, optimizer)(nn.nn.parameters(), lr=0)

    tr_it = TrainScheduler(optim)
    if schedule == '1cycle':
        add_1cycle(tr_it, epochs * len(gt_set), lrate, momentum,
                   momentum - 0.10, weight_decay)
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1, (lrate, lrate), (momentum, momentum),
                        weight_decay, train.annealing_const)

    st_it = cast(TrainStopper, None)  # type: TrainStopper
    if quit == 'early':
        st_it = EarlyStopping(train_loader, min_delta, lag)
    elif quit == 'dumb':
        st_it = EpochStopping(train_loader, epochs)
    else:
        raise click.BadOptionUsage(
            'quit', 'Invalid training interruption scheme {}'.format(quit))

    for epoch, loader in enumerate(st_it):
        with log.progressbar(label='epoch {}/{}'.format(
                epoch, epochs - 1 if epochs > 0 else '∞'),
                             length=len(loader),
                             show_pos=True) as bar:
            for trial, (inp, target) in enumerate(loader):
                tr_it.step()
                inp = inp.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)
                inp = inp.requires_grad_()
                o = nn.nn(inp)
                # height should be 1 by now
                if o.size(2) != 1:
                    raise KrakenInputException(
                        'Expected dimension 3 to be 1, actual {}'.format(
                            o.size(2)))
                o = o.squeeze(2)
                optim.zero_grad()
                # NCW -> WNC
                loss = nn.criterion(
                    o.permute(2, 0, 1),  # type: ignore
                    target,
                    (o.size(2), ),
                    (target.size(1), ))
                logger.info('trial {}'.format(trial))
                # skip the update for samples producing an infinite loss
                if not torch.isinf(loss):
                    loss.backward()
                    optim.step()
                else:
                    logger.debug('infinite loss in trial {}'.format(trial))
                bar.update(1)
        if not epoch % savefreq:
            logger.info('Saving to {}_{}'.format(output, epoch))
            try:
                nn.save_model('{}_{}.mlmodel'.format(output, epoch))
            except Exception as e:
                # a failed checkpoint must not abort the training run
                logger.error('Saving model failed: {}'.format(str(e)))
        if not epoch % report:
            logger.debug('Starting evaluation run')
            nn.eval()
            chars, error = compute_error(rec, list(val_set))
            nn.train()
            if chars:
                accuracy = (chars - error) / chars
            else:
                # empty validation set (e.g. partition of 1 without explicit
                # evaluation files): avoid dividing by zero
                logger.warning(
                    'Validation set contains no characters, reporting accuracy 0.'
                )
                accuracy = 0.0
            report_msg = 'Accuracy report ({}) {:0.4f} {} {}'.format(
                epoch, accuracy, chars, error)
            logger.info(report_msg)
            message(report_msg)
            st_it.update(accuracy)

    if quit == 'early':
        best_msg = 'Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format(
            output, st_it.best_epoch, st_it.best_loss)
        message(best_msg)
        logger.info(best_msg)
        shutil.copy('{}_{}.mlmodel'.format(output, st_it.best_epoch),
                    '{}_best.mlmodel'.format(output))