Example no. 1
def recognizer(input_image, model, pad, no_segmentation, bidi_reordering,
               script_ignore, mode, text_direction, segments) -> None:
    bounds = segments
    scripts = set()

    # Script detection.
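    # With detection data present, mm_rpred dispatches each line to the
    # model matching its script; otherwise the single 'default' model is
    # used for all lines.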
    if bounds['script_detection']:
        for l in bounds['boxes']:
            for t in l:
                scripts.add(t[0])
        it = rpred.mm_rpred(model,
                            input_image,
                            bounds,
                            pad,
                            bidi_reordering=bidi_reordering,
                            script_ignore=script_ignore)
    else:
        it = rpred.rpred(model['default'],
                         input_image,
                         bounds,
                         pad,
                         bidi_reordering=bidi_reordering)

    preds = []
    with log.progressbar(it, label='Processing',
                         length=len(bounds['boxes'])) as bar:
        for pred in bar:
            preds.append(pred)

    print('Recognition results = {}.'.format('\n'.join(s.prediction
                                                       for s in preds)))

    # Dead branch: serialization is disabled here, and `output` and
    # `base_image` are not defined in this scope.
    if False:
        with open_file(output, 'w', encoding='utf-8') as fp:
            print('Serializing as {} into {}'.format(mode, output))
            if mode != 'text':
                from kraken import serialization
                fp.write(
                    serialization.serialize(preds, base_image,
                                            Image.open(base_image).size,
                                            text_direction, scripts, mode))
            else:
                fp.write('\n'.join(s.prediction for s in preds))
Example no. 2
def line_generator(ctx, font, maxlines, encoding, normalization, renormalize,
                   reorder, font_size, font_weight, language, max_length, strip,
                   disable_degradation, alpha, beta, distort, distortion_sigma,
                   legacy, output, text):
    """
    Generates artificial text line training data.
    """
    import os
    import errno
    import click
    import unicodedata

    import numpy as np

    from typing import Set
    from bidi.algorithm import get_display

    from kraken import linegen
    from kraken.lib.exceptions import KrakenCairoSurfaceException
    from kraken.lib.util import make_printable

    lines: Set[str] = set()
    if not text:
        return
    with log.progressbar(text, label='Reading texts') as bar:
        for t in bar:
            with click.open_file(t, encoding=encoding) as fp:
                logger.info('Reading {}'.format(t))
                for l in fp:
                    lines.add(l.rstrip('\r\n'))
    if normalization:
        lines = {unicodedata.normalize(normalization, line) for line in lines}
    if strip:
        lines = {line.strip() for line in lines}
    if max_length:
        lines = {line for line in lines if len(line) < max_length}
    logger.info('Read {} lines'.format(len(lines)))
    message('Read {} unique lines'.format(len(lines)))
    if maxlines and maxlines < len(lines):
        message('Sampling {} lines\t'.format(maxlines), nl=False)
        llist = list(lines)
        # sample without replacement so exactly `maxlines` lines are kept
        lines = set(llist[idx] for idx in np.random.choice(len(llist), maxlines, replace=False))
        message('\u2713', fg='green')
    try:
        os.makedirs(output)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

    # calculate the alphabet and print it for verification purposes
    alphabet: Set[str] = set()
    for line in lines:
        alphabet.update(line)
    chars = []
    combining = []
    for char in sorted(alphabet):
        k = make_printable(char)
        if k != char:
            combining.append(k)
        else:
            chars.append(k)
    message('Σ (len: {})'.format(len(alphabet)))
    message('Symbols: {}'.format(''.join(chars)))
    if combining:
        message('Combining Characters: {}'.format(', '.join(combining)))
    lg = linegen.LineGenerator(font, font_size, font_weight, language)
    with log.progressbar(lines, label='Writing images') as bar:
        for idx, line in enumerate(bar):
            logger.info(line)
            try:
                if renormalize:
                    im = lg.render_line(unicodedata.normalize(renormalize, line))
                else:
                    im = lg.render_line(line)
            except KrakenCairoSurfaceException as e:
                logger.info('{}: {} {}'.format(e.message, e.width, e.height))
                continue
            if not disable_degradation and not legacy:
                im = linegen.degrade_line(im, alpha=alpha, beta=beta)
                im = linegen.distort_line(im, abs(np.random.normal(distort)), abs(np.random.normal(distortion_sigma)))
            elif legacy:
                im = linegen.ocropy_degrade(im)
            im.save('{}/{:06d}.png'.format(output, idx))
            with open('{}/{:06d}.gt.txt'.format(output, idx), 'wb') as fp:
                if reorder:
                    fp.write(get_display(line).encode('utf-8'))
                else:
                    fp.write(line.encode('utf-8'))
Example no. 3
def transcription(ctx, text_direction, scale, bw, maxcolseps,
                  black_colseps, font, font_style, prefill, pad, lines, output,
                  images):
    """
    Creates transcription environments for ground truth generation.
    """
    import json

    from typing import IO, Any, cast
    from PIL import Image

    from kraken import rpred
    from kraken import pageseg
    from kraken import transcribe
    from kraken import binarization

    from kraken.lib import models
    from kraken.lib.util import is_bitonal

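    # Per-image workflow: binarize, segment into lines (or load a JSON
    # segmentation via --lines), optionally prefill transcriptions with a
    # recognition model, then write out an HTML transcription environment.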
    ti = transcribe.TranscriptionInterface(font, font_style)

    if len(images) > 1 and lines:
        raise click.UsageError('--lines option is incompatible with multiple image files')

    if prefill:
        logger.info('Loading model {}'.format(prefill))
        message('Loading RNN', nl=False)
        prefill = models.load_any(prefill)
        message('\u2713', fg='green')

    with log.progressbar(images, label='Reading images') as bar:
        for fp in bar:
            logger.info('Reading {}'.format(fp.name))
            im = Image.open(fp)
            if im.mode not in ['1', 'L', 'P', 'RGB']:
                logger.warning('Input {} is in {} color mode. Converting to RGB'.format(fp.name, im.mode))
                im = im.convert('RGB')
            logger.info('Binarizing page')
            im_bin = binarization.nlbin(im)
            im_bin = im_bin.convert('1')
            logger.info('Segmenting page')
            if not lines:
                res = pageseg.segment(im_bin, text_direction, scale, maxcolseps, black_colseps, pad=pad)
            else:
                with open_file(lines, 'r') as lfp:
                    try:
                        lfp = cast(IO[Any], lfp)
                        res = json.load(lfp)
                    except ValueError as e:
                        raise click.UsageError('{} invalid segmentation: {}'.format(lines, str(e)))
            if prefill:
                it = rpred.rpred(prefill, im_bin, res)
                preds = []
                logger.info('Recognizing')
                for pred in it:
                    logger.debug('{}'.format(pred.prediction))
                    preds.append(pred)
                ti.add_page(im, res, records=preds)
            else:
                ti.add_page(im, res)
            fp.close()
    logger.info('Writing transcription to {}'.format(output.name))
    message('Writing output', nl=False)
    ti.write(output)
    message('\u2713', fg='green')
Example no. 4
def extract(ctx, binarize, normalization, normalize_whitespace, reorder,
            rotate, output, format, transcriptions):
    """
    Extracts image-text pairs from a transcription environment created using
    ``ketos transcribe``.
    """
    import os
    import uuid
    import regex
    import base64
    import unicodedata

    from io import BytesIO
    from PIL import Image
    from lxml import html, etree
    from bidi.algorithm import get_display

    from kraken import binarization

    try:
        os.mkdir(output)
    except FileExistsError:
        # output directory already exists
        pass

    text_transforms = []
    if normalization:
        text_transforms.append(lambda x: unicodedata.normalize(normalization, x))
    if normalize_whitespace:
        text_transforms.append(lambda x: regex.sub(r'\s', ' ', x))
    if reorder:
        text_transforms.append(get_display)

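    # A transcription environment is a single HTML file: page images are
    # embedded as base64 data URIs in <img> tags, and transcribed lines are
    # contenteditable <li> elements carrying their bounding box in a
    # data-bbox attribute.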
    idx = 0
    manifest = []
    with log.progressbar(transcriptions, label='Reading transcriptions') as bar:
        for fp in bar:
            logger.info('Reading {}'.format(fp.name))
            doc = html.parse(fp)
            etree.strip_tags(doc, etree.Comment)
            td = doc.find(".//meta[@itemprop='text_direction']")
            if td is None:
                td = 'horizontal-lr'
            else:
                td = td.attrib['content']

            im = None
            dest_dict = {'output': output, 'idx': 0, 'src': fp.name, 'uuid': str(uuid.uuid4())}
            for section in doc.xpath('//section'):
                img = section.xpath('.//img')[0].get('src')
                fd = BytesIO(base64.b64decode(img.split(',')[1]))
                im = Image.open(fd)
                if not im:
                    logger.info('Skipping {} because image not found'.format(fp.name))
                    break
                if binarize:
                    im = binarization.nlbin(im)
                for line in section.iter('li'):
                    if line.get('contenteditable') and (not u''.join(line.itertext()).isspace() and u''.join(line.itertext())):
                        dest_dict['idx'] = idx
                        dest_dict['uuid'] = str(uuid.uuid4())
                        logger.debug('Writing line {:06d}'.format(idx))
                        l_img = im.crop([int(x) for x in line.get('data-bbox').split(',')])
                        if rotate and td.startswith('vertical'):
                            # rotate the cropped line into horizontal orientation
                            l_img = l_img.rotate(90, expand=True)
                        l_img.save(('{output}/' + format + '.png').format(**dest_dict))
                        manifest.append((format + '.png').format(**dest_dict))
                        text = u''.join(line.itertext()).strip()
                        for func in text_transforms:
                            text = func(text)
                        with open(('{output}/' + format + '.gt.txt').format(**dest_dict), 'wb') as t:
                            t.write(text.encode('utf-8'))
                        idx += 1
    logger.info('Extracted {} lines'.format(idx))
    with open('{}/manifest.txt'.format(output), 'w') as fp:
        fp.write('\n'.join(manifest))
Example no. 5
def test(ctx, model, evaluation_files, device, pad, threads, test_set):
    """
    Evaluate on a test set.
    """
    if not model:
        raise click.UsageError('No model to evaluate given.')

    import os
    import numpy as np

    from typing import List
    from PIL import Image
    from bidi.algorithm import get_display

    from kraken.serialization import render_report
    from kraken.lib import models
    from kraken.lib.dataset import global_align, compute_confusions, generate_input_transforms

    logger.info('Building test set from {} line images'.format(len(test_set) + len(evaluation_files)))

    nn = {}
    for p in model:
        message('Loading model {}\t'.format(p), nl=False)
        nn[p] = models.load_any(p)
        message('\u2713', fg='green')

    test_set = list(test_set)

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    next(iter(nn.values())).nn.set_num_threads(threads)

    # merge training_files into ground_truth list
    if evaluation_files:
        test_set.extend(evaluation_files)

    if len(test_set) == 0:
        raise click.UsageError('No evaluation data was provided to the test command. Use `-e` or the `test_set` argument.')

    def _get_text(im):
        with open(os.path.splitext(im)[0] + '.gt.txt', 'r') as fp:
            return get_display(fp.read())

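    # Character accuracy per model: globally align ground truth and
    # prediction, accumulate edit errors, and report (chars - error) / chars.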
    acc_list = []
    for p, net in nn.items():
        algn_gt: List[str] = []
        algn_pred: List[str] = []
        chars = 0
        error = 0
        message('Evaluating {}'.format(p))
        logger.info('Evaluating {}'.format(p))
        batch, channels, height, width = net.nn.input
        ts = generate_input_transforms(batch, height, width, channels, pad)
        with log.progressbar(test_set, label='Evaluating') as bar:
            for im_path in bar:
                i = ts(Image.open(im_path))
                text = _get_text(im_path)
                pred = net.predict_string(i)
                chars += len(text)
                c, algn1, algn2 = global_align(text, pred)
                algn_gt.extend(algn1)
                algn_pred.extend(algn2)
                error += c
        acc_list.append((chars-error)/chars)
        confusions, scripts, ins, dels, subs = compute_confusions(algn_gt, algn_pred)
        rep = render_report(p, chars, error, confusions, scripts, ins, dels, subs)
        logger.info(rep)
        message(rep)
    logger.info('Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(acc_list) * 100, np.std(acc_list) * 100))
    message('Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(acc_list) * 100, np.std(acc_list) * 100))
Example no. 6
def train(ctx, pad, output, spec, append, load, freq, quit, epochs,
          lag, min_delta, device, optimizer, lrate, momentum, weight_decay,
          schedule, partition, normalization, normalize_whitespace, codec,
          resize, reorder, training_files, evaluation_files, preload, threads,
          ground_truth):
    """
    Trains a model from image-text pairs.
    """
    if not load and append:
        raise click.BadOptionUsage('append', 'append option requires loading an existing model')

    if resize != 'fail' and not load:
        raise click.BadOptionUsage('resize', 'resize option requires loading an existing model')

    import re
    import torch
    import shutil
    import numpy as np

    from torch.utils.data import DataLoader

    from typing import cast

    from kraken.lib import models, vgsl, train
    from kraken.lib.util import make_printable
    from kraken.lib.exceptions import KrakenEncodeException, KrakenInputException
    from kraken.lib.train import EarlyStopping, EpochStopping, TrainStopper, TrainScheduler, add_1cycle
    from kraken.lib.codec import PytorchCodec
    from kraken.lib.dataset import GroundTruthDataset, generate_input_transforms

    logger.info('Building ground truth set from {} line images'.format(len(ground_truth) + len(training_files)))

    completed_epochs = 0
    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise the output size is still unknown.
    nn = None
    #hyper_fields = ['freq', 'quit', 'epochs', 'lag', 'min_delta', 'optimizer', 'lrate', 'momentum', 'weight_decay', 'schedule', 'partition', 'normalization', 'normalize_whitespace', 'reorder', 'preload', 'completed_epochs', 'output']

    if load:
        logger.info('Loading existing model from {} '.format(load))
        message('Loading existing model from {}'.format(load), nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        #if nn.user_metadata and load_hyper_parameters:
        #    for param in hyper_fields:
        #        if param in nn.user_metadata:
        #            logger.info('Setting \'{}\' to \'{}\''.format(param, nn.user_metadata[param]))
        #            message('Setting \'{}\' to \'{}\''.format(param, nn.user_metadata[param]))
        #            locals()[param] = nn.user_metadata[param]
        message('\u2713', fg='green', nl=False)

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
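    # For illustration (hypothetical spec): '[1,48,0,1 Cr3,3,32 O1c10]'
    # parses to batch=1, height=48, width=0 (variable width), channels=1.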
    if not nn:
        spec = spec.strip()
        if spec[0] != '[' or spec[-1] != ']':
            raise click.BadOptionUsage('spec', 'VGSL spec {} not bracketed'.format(spec))
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage('spec', 'Invalid input spec {}'.format(blocks[0]))
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch, height, width, channels, pad)
    except KrakenInputException as e:
        raise click.BadOptionUsage('spec', str(e))

    # disable automatic partition when given evaluation set explicitly
    if evaluation_files:
        partition = 1
    ground_truth = list(ground_truth)

    # merge training_files into ground_truth list
    if training_files:
        ground_truth.extend(training_files)

    if len(ground_truth) == 0:
        raise click.UsageError('No training data was provided to the train command. Use `-t` or the `ground_truth` argument.')

    np.random.shuffle(ground_truth)

    if len(ground_truth) > 2500 and not preload:
        logger.info('Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter')
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    tr_im = ground_truth[:int(len(ground_truth) * partition)]
    if evaluation_files:
        logger.debug('Using {} lines from explicit eval set'.format(len(evaluation_files)))
        te_im = evaluation_files
    else:
        te_im = ground_truth[int(len(ground_truth) * partition):]
        logger.debug('Taking {} lines from training for evaluation'.format(len(te_im)))

    # set multiprocessing tensor sharing strategy
    if 'file_system' in torch.multiprocessing.get_all_sharing_strategies():
        logger.debug('Setting multiprocessing tensor sharing strategy to file_system')
        torch.multiprocessing.set_sharing_strategy('file_system')

    gt_set = GroundTruthDataset(normalization=normalization,
                                whitespace_normalization=normalize_whitespace,
                                reorder=reorder,
                                im_transforms=transforms,
                                preload=preload)
    with log.progressbar(tr_im, label='Building training set') as bar:
        for im in bar:
            logger.debug('Adding line {} to training set'.format(im))
            try:
                gt_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    val_set = GroundTruthDataset(normalization=normalization,
                                 whitespace_normalization=normalize_whitespace,
                                 reorder=reorder,
                                 im_transforms=transforms,
                                 preload=preload)
    with log.progressbar(te_im, label='Building validation set') as bar:
        for im in bar:
            logger.debug('Adding line {} to validation set'.format(im))
            try:
                val_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    logger.info('Training set {} lines, validation set {} lines, alphabet {} symbols'.format(len(gt_set._images), len(val_set._images), len(gt_set.alphabet)))
    alpha_diff_only_train = set(gt_set.alphabet).difference(set(val_set.alphabet))
    alpha_diff_only_val = set(val_set.alphabet).difference(set(gt_set.alphabet))
    if alpha_diff_only_train:
        logger.warning('alphabet mismatch: chars in training set only: {} (not included in accuracy test during training)'.format(alpha_diff_only_train))
    if alpha_diff_only_val:
        logger.warning('alphabet mismatch: chars in validation set only: {} (not trained)'.format(alpha_diff_only_val))
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(), key=lambda x: x[1], reverse=True):
        char = make_printable(k)
        if char == k:
            char = '\t' + char
        logger.info(u'{}\t{}'.format(char, v))

    logger.debug('Encoding training set')

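    # Codec resolution in the branches below: 'append' slices the loaded
    # model and re-encodes with the training-set codec; 'load' reuses the
    # network codec and, on mismatch, either fails ('fail'), adds missing
    # labels ('add'), or merges codecs and prunes unused outputs ('both');
    # without a loaded model a fresh network is built around the new codec.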
    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label()+1)
        logger.info('Appending {} to existing model {} after {}'.format(spec, nn.spec, append))
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info('Assembled model spec: {}'.format(nn.spec))
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException as e:
            message('Network codec not compatible with training set')
            alpha_diff = set(gt_set.alphabet).difference(set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error('Training data and model codec alphabets mismatch: {}'.format(alpha_diff))
                ctx.exit(code=1)
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info('Resizing codec to include {} new code points'.format(len(alpha_diff)))
                codec.c2l.update({k: [v] for v, k in enumerate(alpha_diff, start=codec.max_label()+1)})
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info('Resizing last layer in network to {} outputs'.format(codec.max_label()+1))
                nn.resize_output(codec.max_label()+1)
                gt_set.encode(nn.codec)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ', nl=False)
                logger.info('Resizing network or given codec to {} code sequences'.format(len(gt_set.alphabet)))
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info('Deleting {} output classes from network ({} retained)'.format(len(del_labels), len(codec)-len(del_labels)))
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label()+1, del_labels)
                message('\u2713', fg='green')
            else:
                raise click.BadOptionUsage('resize', 'Invalid resize value {}'.format(resize))
    else:
        gt_set.encode(codec)
        logger.info('Creating new model {} with {} outputs'.format(spec, gt_set.codec.max_label()+1))
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label()+1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    # half the number of data loading processes if device isn't cuda and we haven't enabled preloading
    if device == 'cpu' and not preload:
        loader_threads = threads // 2
    else:
        loader_threads = threads
    train_loader = DataLoader(gt_set, batch_size=1, shuffle=True, num_workers=loader_threads, pin_memory=True)
    threads -= loader_threads

    # don't encode validation set as the alphabets may not match causing encoding failures
    val_set.training_set = list(zip(val_set._images, val_set._gt))

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(optimizer, lrate, momentum))

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    nn.set_num_threads(threads)

    logger.debug('Moving model to device {}'.format(device))
    optim = getattr(torch.optim, optimizer)(nn.nn.parameters(), lr=0)

    if 'accuracy' not in nn.user_metadata:
        nn.user_metadata['accuracy'] = []

    tr_it = TrainScheduler(optim)
    if schedule == '1cycle':
        add_1cycle(tr_it, int(len(gt_set) * epochs), lrate, momentum, momentum - 0.10, weight_decay)
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1, (lrate, lrate), (momentum, momentum), weight_decay, train.annealing_const)

    if quit == 'early':
        st_it = EarlyStopping(min_delta, lag)
    elif quit == 'dumb':
        st_it = EpochStopping(epochs - completed_epochs)
    else:
        raise click.BadOptionUsage('quit', 'Invalid training interruption scheme {}'.format(quit))

    #for param in hyper_fields:
    #    logger.debug('Setting \'{}\' to \'{}\' in model metadata'.format(param, locals()[param]))
    #    nn.user_metadata[param] = locals()[param]

    trainer = train.KrakenTrainer(model=nn,
                                  optimizer=optim,
                                  device=device,
                                  filename_prefix=output,
                                  event_frequency=freq,
                                  train_set=train_loader,
                                  val_set=val_set,
                                  stopper=st_it)

    trainer.add_lr_scheduler(tr_it)

    with log.progressbar(label='stage {}/{}'.format(1, trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞'),
                         length=trainer.event_it, show_pos=True) as bar:

        def _draw_progressbar():
            bar.update(1)

        def _print_eval(epoch, accuracy, chars, error):
            message('Accuracy report ({}) {:0.4f} {} {}'.format(epoch, accuracy, chars, error))
            # reset progress bar
            bar.label = 'stage {}/{}'.format(epoch+1, trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞')
            bar.pos = 0
            bar.finished = False

        trainer.run(_print_eval, _draw_progressbar)

    if quit == 'early':
        message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format(output, trainer.stopper.best_epoch, trainer.stopper.best_loss))
        logger.info('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.format(output, trainer.stopper.best_epoch, trainer.stopper.best_loss))
        shutil.copy('{}_{}.mlmodel'.format(output, trainer.stopper.best_epoch), '{}_best.mlmodel'.format(output))
Example no. 7
def process_pipeline(subcommands, input, batch_input, suffix, verbose,
                     format_type, pdf_format, **args):
    """
    Helper function calling the partials returned by each subcommand and
    placing their respective outputs in temporary files.
    """
    import os
    import glob
    import uuid
    import click
    import tempfile

    input = list(input)
    # expand batch inputs
    if batch_input and suffix:
        for batch_expr in batch_input:
            for in_file in glob.glob(batch_expr, recursive=True):
                input.append(
                    (in_file,
                     '{}{}'.format(os.path.splitext(in_file)[0], suffix)))

    # parse pdfs
    if format_type == 'pdf':
        import pyvips

        if not batch_input:
            logger.warning(
                'PDF inputs were not given through the batch option; manual '
                'output filenames will be ignored and `-o` used instead.'
            )
        new_input = []
        num_pages = 0
        for (fpath, _) in input:
            doc = pyvips.Image.new_from_file(fpath,
                                             dpi=300,
                                             n=-1,
                                             access="sequential")
            if 'n-pages' in doc.get_fields():
                num_pages += doc.get('n-pages')

        with log.progressbar(length=num_pages,
                             label='Extracting PDF pages') as bar:
            for (fpath, _) in input:
                try:
                    doc = pyvips.Image.new_from_file(fpath,
                                                     dpi=300,
                                                     n=-1,
                                                     access="sequential")
                    if 'n-pages' not in doc.get_fields():
                        logger.warning(
                            f'{fpath} does not contain pages. Skipping.')
                        continue
                    n_pages = doc.get('n-pages')

                    dest_dict = {'idx': -1, 'src': fpath, 'uuid': None}
                    for i in range(0, n_pages):
                        dest_dict['idx'] += 1
                        dest_dict['uuid'] = str(uuid.uuid4())
                        fd, filename = tempfile.mkstemp(suffix='.png')
                        os.close(fd)
                        doc = pyvips.Image.new_from_file(fpath,
                                                         dpi=300,
                                                         page=i,
                                                         access="sequential")
                        logger.info(
                            f'Saving temporary image {fpath}:{dest_dict["idx"]} to {filename}'
                        )
                        doc.write_to_file(filename)
                        new_input.append(
                            (filename,
                             pdf_format.format(**dest_dict) + suffix))
                        bar.update(1)
                except pyvips.error.Error:
                    logger.warning(f'{fpath} is not a PDF file. Skipping.')
        input = new_input

    ctx = click.get_current_context()

    for io_pair in input:
        ctx.meta['first_process'] = True
        ctx.meta['last_process'] = False
        ctx.meta['orig_file'] = io_pair[0]
        if 'base_image' in ctx.meta:
            del ctx.meta['base_image']
        try:
            tmps = [tempfile.mkstemp() for _ in subcommands[1:]]
            for tmp in tmps:
                os.close(tmp[0])
            fc = [io_pair[0]] + [tmp[1] for tmp in tmps] + [io_pair[1]]
            for idx, (task, task_input,
                      task_output) in enumerate(zip(subcommands, fc, fc[1:])):
                if len(fc) - 2 == idx:
                    ctx.meta['last_process'] = True
                task(input=task_input, output=task_output)
        except Exception as e:
            logger.error(f'Failed processing {io_pair[0]}: {str(e)}')
            if ctx.meta['raise_failed'] is True:
                raise
        finally:
            for f in fc[1:-1]:
                os.unlink(f)
            # clean up temporary PDF image files
            if format_type == 'pdf':
                logger.debug(f'unlinking {fc[0]}')
                os.unlink(fc[0])
Example no. 8
def recognizer(model, pad, no_segmentation, bidi_reordering, script_ignore,
               input, output) -> None:

    import json

    from typing import IO, Any, cast
    from PIL import Image

    from kraken import rpred

    ctx = click.get_current_context()

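    # Segmentation (bounds) is resolved in priority order: taken from a
    # non-image input document, loaded from a JSON segmentation file, or
    # synthesized as a single box over the whole page when --no-segmentation
    # is set.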
    bounds = None
    if 'base_image' not in ctx.meta:
        ctx.meta['base_image'] = input

    if ctx.meta['first_process']:
        if ctx.meta['input_format_type'] != 'image':
            doc = get_input_parser(ctx.meta['input_format_type'])(input)
            ctx.meta['base_image'] = doc['image']
            doc['text_direction'] = 'horizontal-lr'
            bounds = doc
    try:
        im = Image.open(ctx.meta['base_image'])
    except IOError as e:
        raise click.BadParameter(str(e))

    if not bounds and ctx.meta['base_image'] != input:
        with open_file(input, 'r') as fp:
            try:
                fp = cast(IO[Any], fp)
                bounds = json.load(fp)
            except ValueError as e:
                raise click.UsageError(
                    f'{input} invalid segmentation: {str(e)}')
    elif not bounds:
        if no_segmentation:
            bounds = {
                'script_detection': False,
                'text_direction': 'horizontal-lr',
                'boxes': [(0, 0) + im.size]
            }
        else:
            raise click.UsageError(
                'No line segmentation given. Add one with the input or run `segment` first.'
            )
    elif no_segmentation:
        logger.warning(
            'no_segmentation mode enabled but segmentation defined. Ignoring --no-segmentation option.'
        )

    scripts = set()
    # script detection
    if 'script_detection' in bounds and bounds['script_detection']:
        it = rpred.mm_rpred(model,
                            im,
                            bounds,
                            pad,
                            bidi_reordering=bidi_reordering,
                            script_ignore=script_ignore)
    else:
        it = rpred.rpred(model['default'],
                         im,
                         bounds,
                         pad,
                         bidi_reordering=bidi_reordering)

    preds = []

    with log.progressbar(it, label='Processing') as bar:
        for pred in bar:
            preds.append(pred)

    with open_file(output, 'w', encoding='utf-8') as fp:
        fp = cast(IO[Any], fp)
        message(f'Writing recognition results for {ctx.meta["orig_file"]}\t',
                nl=False)
        logger.info('Serializing as {} into {}'.format(ctx.meta['output_mode'],
                                                       output))
        if ctx.meta['output_mode'] != 'native':
            from kraken import serialization
            fp.write(
                serialization.serialize(
                    preds, ctx.meta['base_image'],
                    Image.open(ctx.meta['base_image']).size,
                    ctx.meta['text_direction'], scripts,
                    bounds['regions'] if 'regions' in bounds else None,
                    ctx.meta['output_mode']))
        else:
            fp.write('\n'.join(s.prediction for s in preds))
        message('\u2713', fg='green')
Example no. 9
def transcription(ctx, text_direction, scale, bw, maxcolseps, black_colseps,
                  font, font_style, prefill, pad, lines, output, images):
    """
    Creates transcription environments for ground truth generation.
    """
    import json

    from typing import IO, Any, cast
    from PIL import Image

    from kraken import rpred
    from kraken import pageseg
    from kraken import transcribe
    from kraken import binarization

    from kraken.lib import models
    from kraken.lib.util import is_bitonal

    ti = transcribe.TranscriptionInterface(font, font_style)

    if len(images) > 1 and lines:
        raise click.UsageError(
            '--lines option is incompatible with multiple image files')

    if prefill:
        logger.info('Loading model {}'.format(prefill))
        message('Loading RNN', nl=False)
        prefill = models.load_any(prefill)
        message('\u2713', fg='green')

    with log.progressbar(images, label='Reading images') as bar:
        for fp in bar:
            logger.info('Reading {}'.format(fp.name))
            im = Image.open(fp)
            if im.mode not in ['1', 'L', 'P', 'RGB']:
                logger.warning(
                    'Input {} is in {} color mode. Converting to RGB'.format(
                        fp.name, im.mode))
                im = im.convert('RGB')
            logger.info('Binarizing page')
            im_bin = binarization.nlbin(im)
            im_bin = im_bin.convert('1')
            logger.info('Segmenting page')
            if not lines:
                res = pageseg.segment(im_bin,
                                      text_direction,
                                      scale,
                                      maxcolseps,
                                      black_colseps,
                                      pad=pad)
            else:
                with open_file(lines, 'r') as lfp:
                    try:
                        lfp = cast(IO[Any], lfp)
                        res = json.load(lfp)
                    except ValueError as e:
                        raise click.UsageError(
                            '{} invalid segmentation: {}'.format(
                                lines, str(e)))
            if prefill:
                it = rpred.rpred(prefill, im_bin, res)
                preds = []
                logger.info('Recognizing')
                for pred in it:
                    logger.debug('{}'.format(pred.prediction))
                    preds.append(pred)
                ti.add_page(im, res, records=preds)
            else:
                ti.add_page(im, res)
            fp.close()
    logger.info('Writing transcription to {}'.format(output.name))
    message('Writing output', nl=False)
    ti.write(output)
    message('\u2713', fg='green')
Example no. 10
def extract(ctx, binarize, normalization, normalize_whitespace, reorder,
            rotate, output, format, transcriptions):
    """
    Extracts image-text pairs from a transcription environment created using
    ``ketos transcribe``.
    """
    import os
    import uuid
    import regex
    import base64
    import unicodedata

    from io import BytesIO
    from PIL import Image
    from lxml import html, etree
    from bidi.algorithm import get_display

    from kraken import binarization

    try:
        os.mkdir(output)
    except FileExistsError:
        # output directory already exists
        pass

    text_transforms = []
    if normalization:
        text_transforms.append(
            lambda x: unicodedata.normalize(normalization, x))
    if normalize_whitespace:
        text_transforms.append(lambda x: regex.sub(r'\s', ' ', x))
    if reorder:
        text_transforms.append(get_display)

    idx = 0
    manifest = []
    with log.progressbar(transcriptions,
                         label='Reading transcriptions') as bar:
        for fp in bar:
            logger.info('Reading {}'.format(fp.name))
            doc = html.parse(fp)
            etree.strip_tags(doc, etree.Comment)
            td = doc.find(".//meta[@itemprop='text_direction']")
            if td is None:
                td = 'horizontal-lr'
            else:
                td = td.attrib['content']

            im = None
            dest_dict = {
                'output': output,
                'idx': 0,
                'src': fp.name,
                'uuid': str(uuid.uuid4())
            }
            for section in doc.xpath('//section'):
                img = section.xpath('.//img')[0].get('src')
                fd = BytesIO(base64.b64decode(img.split(',')[1]))
                im = Image.open(fd)
                if not im:
                    logger.info('Skipping {} because image not found'.format(
                        fp.name))
                    break
                if binarize:
                    im = binarization.nlbin(im)
                for line in section.iter('li'):
                    if line.get('contenteditable') and (
                            not u''.join(line.itertext()).isspace()
                            and u''.join(line.itertext())):
                        dest_dict['idx'] = idx
                        dest_dict['uuid'] = str(uuid.uuid4())
                        logger.debug('Writing line {:06d}'.format(idx))
                        l_img = im.crop(
                            [int(x) for x in line.get('data-bbox').split(',')])
                        if rotate and td.startswith('vertical'):
                            # rotate the cropped line into horizontal orientation
                            l_img = l_img.rotate(90, expand=True)
                        l_img.save(('{output}/' + format +
                                    '.png').format(**dest_dict))
                        manifest.append((format + '.png').format(**dest_dict))
                        text = u''.join(line.itertext()).strip()
                        for func in text_transforms:
                            text = func(text)
                        with open(('{output}/' + format +
                                   '.gt.txt').format(**dest_dict), 'wb') as t:
                            t.write(text.encode('utf-8'))
                        idx += 1
    logger.info('Extracted {} lines'.format(idx))
    with open('{}/manifest.txt'.format(output), 'w') as fp:
        fp.write('\n'.join(manifest))
Example no. 11
def test(ctx, model, evaluation_files, device, pad, threads, test_set):
    """
    Evaluate on a test set.
    """
    if not model:
        raise click.UsageError('No model to evaluate given.')

    import os
    import numpy as np

    from typing import List
    from PIL import Image
    from bidi.algorithm import get_display

    from kraken.serialization import render_report
    from kraken.lib import models
    from kraken.lib.dataset import global_align, compute_confusions, generate_input_transforms

    logger.info('Building test set from {} line images'.format(
        len(test_set) + len(evaluation_files)))

    nn = {}
    for p in model:
        message('Loading model {}\t'.format(p), nl=False)
        nn[p] = models.load_any(p)
        message('\u2713', fg='green')

    test_set = list(test_set)

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    next(iter(nn.values())).nn.set_num_threads(threads)

    # merge training_files into ground_truth list
    if evaluation_files:
        test_set.extend(evaluation_files)

    if len(test_set) == 0:
        raise click.UsageError(
            'No evaluation data was provided to the test command. Use `-e` or the `test_set` argument.'
        )

    def _get_text(im):
        with open(os.path.splitext(im)[0] + '.gt.txt', 'r') as fp:
            return get_display(fp.read())

    acc_list = []
    for p, net in nn.items():
        algn_gt: List[str] = []
        algn_pred: List[str] = []
        chars = 0
        error = 0
        message('Evaluating {}'.format(p))
        logger.info('Evaluating {}'.format(p))
        batch, channels, height, width = net.nn.input
        ts = generate_input_transforms(batch, height, width, channels, pad)
        with log.progressbar(test_set, label='Evaluating') as bar:
            for im_path in bar:
                i = ts(Image.open(im_path))
                text = _get_text(im_path)
                pred = net.predict_string(i)
                chars += len(text)
                c, algn1, algn2 = global_align(text, pred)
                algn_gt.extend(algn1)
                algn_pred.extend(algn2)
                error += c
        acc_list.append((chars - error) / chars)
        confusions, scripts, ins, dels, subs = compute_confusions(
            algn_gt, algn_pred)
        rep = render_report(p, chars, error, confusions, scripts, ins, dels,
                            subs)
        logger.info(rep)
        message(rep)
    logger.info('Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(
        np.mean(acc_list) * 100,
        np.std(acc_list) * 100))
    message('Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(
        np.mean(acc_list) * 100,
        np.std(acc_list) * 100))
Example no. 12
def train(ctx, pad, output, spec, append, load, freq, quit, epochs, lag,
          min_delta, device, optimizer, lrate, momentum, weight_decay,
          schedule, partition, normalization, normalize_whitespace, codec,
          resize, reorder, training_files, evaluation_files, preload, threads,
          ground_truth):
    """
    Trains a model from image-text pairs.
    """
    if not load and append:
        raise click.BadOptionUsage(
            'append', 'append option requires loading an existing model')

    if resize != 'fail' and not load:
        raise click.BadOptionUsage(
            'resize', 'resize option requires loading an existing model')

    import re
    import torch
    import shutil
    import numpy as np

    from torch.utils.data import DataLoader

    from typing import cast

    from kraken.lib import models, vgsl, train
    from kraken.lib.util import make_printable
    from kraken.lib.exceptions import KrakenEncodeException, KrakenInputException
    from kraken.lib.train import EarlyStopping, EpochStopping, TrainStopper, TrainScheduler, add_1cycle
    from kraken.lib.codec import PytorchCodec
    from kraken.lib.dataset import GroundTruthDataset, generate_input_transforms

    logger.info('Building ground truth set from {} line images'.format(
        len(ground_truth) + len(training_files)))

    completed_epochs = 0
    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise the output size is still unknown.
    nn = None
    #hyper_fields = ['freq', 'quit', 'epochs', 'lag', 'min_delta', 'optimizer', 'lrate', 'momentum', 'weight_decay', 'schedule', 'partition', 'normalization', 'normalize_whitespace', 'reorder', 'preload', 'completed_epochs', 'output']

    if load:
        logger.info('Loading existing model from {} '.format(load))
        message('Loading existing model from {}'.format(load), nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        #if nn.user_metadata and load_hyper_parameters:
        #    for param in hyper_fields:
        #        if param in nn.user_metadata:
        #            logger.info('Setting \'{}\' to \'{}\''.format(param, nn.user_metadata[param]))
        #            message('Setting \'{}\' to \'{}\''.format(param, nn.user_metadata[param]))
        #            locals()[param] = nn.user_metadata[param]
        message('\u2713', fg='green', nl=False)

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        if spec[0] != '[' or spec[-1] != ']':
            raise click.BadOptionUsage(
                'spec', 'VGSL spec {} not bracketed'.format(spec))
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage(
                'spec', 'Invalid input spec {}'.format(blocks[0]))
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch, height, width, channels,
                                               pad)
    except KrakenInputException as e:
        raise click.BadOptionUsage('spec', str(e))

    # disable automatic partition when given evaluation set explicitly
    if evaluation_files:
        partition = 1
    ground_truth = list(ground_truth)

    # merge training_files into ground_truth list
    if training_files:
        ground_truth.extend(training_files)

    if len(ground_truth) == 0:
        raise click.UsageError(
            'No training data was provided to the train command. Use `-t` or the `ground_truth` argument.'
        )

    np.random.shuffle(ground_truth)

    if len(ground_truth) > 2500 and not preload:
        logger.info(
            'Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter'
        )
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    tr_im = ground_truth[:int(len(ground_truth) * partition)]
    if evaluation_files:
        logger.debug('Using {} lines from explicit eval set'.format(
            len(evaluation_files)))
        te_im = evaluation_files
    else:
        te_im = ground_truth[int(len(ground_truth) * partition):]
        logger.debug('Taking {} lines from training for evaluation'.format(
            len(te_im)))

    # set multiprocessing tensor sharing strategy
    if 'file_system' in torch.multiprocessing.get_all_sharing_strategies():
        logger.debug(
            'Setting multiprocessing tensor sharing strategy to file_system')
        torch.multiprocessing.set_sharing_strategy('file_system')

    gt_set = GroundTruthDataset(normalization=normalization,
                                whitespace_normalization=normalize_whitespace,
                                reorder=reorder,
                                im_transforms=transforms,
                                preload=preload)
    with log.progressbar(tr_im, label='Building training set') as bar:
        for im in bar:
            logger.debug('Adding line {} to training set'.format(im))
            try:
                gt_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    val_set = GroundTruthDataset(normalization=normalization,
                                 whitespace_normalization=normalize_whitespace,
                                 reorder=reorder,
                                 im_transforms=transforms,
                                 preload=preload)
    with log.progressbar(te_im, label='Building validation set') as bar:
        for im in bar:
            logger.debug('Adding line {} to validation set'.format(im))
            try:
                val_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    logger.info(
        'Training set {} lines, validation set {} lines, alphabet {} symbols'.
        format(len(gt_set._images), len(val_set._images),
               len(gt_set.alphabet)))
    alpha_diff = set(gt_set.alphabet).symmetric_difference(
        set(val_set.alphabet))
    if alpha_diff:
        logger.warning('alphabet mismatch {}'.format(alpha_diff))
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(),
                       key=lambda x: x[1],
                       reverse=True):
        char = make_printable(k)
        if char == k:
            char = '\t' + char
        logger.info(u'{}\t{}'.format(char, v))

    logger.debug('Encoding training set')

    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        logger.info('Appending {} to existing model {} after {}'.format(
            spec, nn.spec, append))
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info('Assembled model spec: {}'.format(nn.spec))
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException as e:
            message('Network codec not compatible with training set')
            alpha_diff = set(gt_set.alphabet).difference(set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error(
                    'Training data and model codec alphabets mismatch: {}'.
                    format(alpha_diff))
                ctx.exit(code=1)
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info(
                    'Resizing codec to include {} new code points'.format(
                        len(alpha_diff)))
                codec.c2l.update({
                    k: [v]
                    for v, k in enumerate(alpha_diff,
                                          start=codec.max_label() + 1)
                })
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info(
                    'Resizing last layer in network to {} outputs'.format(
                        codec.max_label() + 1))
                nn.resize_output(codec.max_label() + 1)
                gt_set.encode(nn.codec)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ', nl=False)
                logger.info(
                    'Resizing network or given codec to {} code sequences'.
                    format(len(gt_set.alphabet)))
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info(
                    'Deleting {} output classes from network ({} retained)'.
                    format(len(del_labels),
                           len(codec) - len(del_labels)))
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label() + 1, del_labels)
                message('\u2713', fg='green')
            else:
                raise click.BadOptionUsage(
                    'resize', 'Invalid resize value {}'.format(resize))
    else:
        gt_set.encode(codec)
        logger.info('Creating new model {} with {} outputs'.format(
            spec,
            gt_set.codec.max_label() + 1))
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    # half the number of data loading processes if device isn't cuda and we haven't enabled preloading
    if device == 'cpu' and not preload:
        loader_threads = threads // 2
    else:
        loader_threads = threads
    train_loader = DataLoader(gt_set,
                              batch_size=1,
                              shuffle=True,
                              num_workers=loader_threads,
                              pin_memory=True)
    threads -= loader_threads

    # don't encode validation set as the alphabets may not match causing encoding failures
    val_set.training_set = list(zip(val_set._images, val_set._gt))

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(
        optimizer, lrate, momentum))

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    nn.set_num_threads(threads)

    logger.debug('Moving model to device {}'.format(device))
    optim = getattr(torch.optim, optimizer)(nn.nn.parameters(), lr=0)

    if 'accuracy' not in nn.user_metadata:
        nn.user_metadata['accuracy'] = []

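    # Schedule selection: '1cycle' ramps learning rate and momentum over the
    # whole run via add_1cycle; any other value keeps them constant through a
    # single annealing_const phase.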
    tr_it = TrainScheduler(optim)
    if schedule == '1cycle':
        add_1cycle(tr_it, int(len(gt_set) * epochs), lrate, momentum,
                   momentum - 0.10, weight_decay)
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1, (lrate, lrate), (momentum, momentum), weight_decay,
                        train.annealing_const)

    if quit == 'early':
        st_it = EarlyStopping(min_delta, lag)
    elif quit == 'dumb':
        st_it = EpochStopping(epochs - completed_epochs)
    else:
        raise click.BadOptionUsage(
            'quit', 'Invalid training interruption scheme {}'.format(quit))

    #for param in hyper_fields:
    #    logger.debug('Setting \'{}\' to \'{}\' in model metadata'.format(param, locals()[param]))
    #    nn.user_metadata[param] = locals()[param]

    trainer = train.KrakenTrainer(model=nn,
                                  optimizer=optim,
                                  device=device,
                                  filename_prefix=output,
                                  event_frequency=freq,
                                  train_set=train_loader,
                                  val_set=val_set,
                                  stopper=st_it)

    trainer.add_lr_scheduler(tr_it)

    with log.progressbar(label='stage {}/{}'.format(
            1, trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞'),
                         length=trainer.event_it,
                         show_pos=True) as bar:

        def _draw_progressbar():
            bar.update(1)

        def _print_eval(epoch, accuracy, chars, error):
            message('Accuracy report ({}) {:0.4f} {} {}'.format(
                epoch, accuracy, chars, error))
            # reset progress bar
            bar.label = 'stage {}/{}'.format(
                epoch + 1,
                trainer.stopper.epochs if trainer.stopper.epochs > 0 else '∞')
            bar.pos = 0
            bar.finished = False

        trainer.run(_print_eval, _draw_progressbar)

    if quit == 'early':
        message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
                format(output, trainer.stopper.best_epoch,
                       trainer.stopper.best_loss))
        logger.info(
            'Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
            format(output, trainer.stopper.best_epoch,
                   trainer.stopper.best_loss))
        shutil.copy('{}_{}.mlmodel'.format(output, trainer.stopper.best_epoch),
                    '{}_best.mlmodel'.format(output))
Example no. 13
def line_generator(ctx, font, maxlines, encoding, normalization, renormalize,
                   reorder, font_size, font_weight, language, max_length,
                   strip, disable_degradation, alpha, beta, distort,
                   distortion_sigma, legacy, output, text):
    """
    Generates artificial text line training data.
    """
    import errno
    import numpy as np

    from kraken import linegen
    from kraken.lib.util import make_printable

    lines: Set[str] = set()
    if not text:
        return
    with log.progressbar(text, label='Reading texts') as bar:
        for t in bar:
            with click.open_file(t, encoding=encoding) as fp:
                logger.info('Reading {}'.format(t))
                for l in fp:
                    lines.add(l.rstrip('\r\n'))
    if normalization:
        lines = set(
            [unicodedata.normalize(normalization, line) for line in lines])
    if strip:
        lines = set([line.strip() for line in lines])
    if max_length:
        lines = set([line for line in lines if len(line) < max_length])
    logger.info('Read {} lines'.format(len(lines)))
    message('Read {} unique lines'.format(len(lines)))
    if maxlines and maxlines < len(lines):
        message('Sampling {} lines\t'.format(maxlines), nl=False)
        llist = list(lines)
        lines = set(llist[idx]
                    for idx in np.random.randint(0, len(llist), maxlines))
        message('\u2713', fg='green')
    try:
        os.makedirs(output)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise

    # calculate the alphabet and print it for verification purposes
    alphabet: Set[str] = set()
    for line in lines:
        alphabet.update(line)
    chars = []
    combining = []
    for char in sorted(alphabet):
        k = make_printable(char)
        if k != char:
            combining.append(k)
        else:
            chars.append(k)
    message('Σ (len: {})'.format(len(alphabet)))
    message('Symbols: {}'.format(''.join(chars)))
    if combining:
        message('Combining Characters: {}'.format(', '.join(combining)))
    lg = linegen.LineGenerator(font, font_size, font_weight, language)
    with log.progressbar(lines, label='Writing images') as bar:
        for idx, line in enumerate(bar):
            logger.info(line)
            try:
                if renormalize:
                    im = lg.render_line(
                        unicodedata.normalize(renormalize, line))
                else:
                    im = lg.render_line(line)
            except KrakenCairoSurfaceException as e:
                logger.info('{}: {} {}'.format(e.message, e.width, e.height))
                continue
            if not disable_degradation and not legacy:
                im = linegen.degrade_line(im, alpha=alpha, beta=beta)
                im = linegen.distort_line(
                    im, abs(np.random.normal(distort)),
                    abs(np.random.normal(distortion_sigma)))
            elif legacy:
                im = linegen.ocropy_degrade(im)
            im.save('{}/{:06d}.png'.format(output, idx))
            with open('{}/{:06d}.gt.txt'.format(output, idx), 'wb') as fp:
                if reorder:
                    fp.write(get_display(line).encode('utf-8'))
                else:
                    fp.write(line.encode('utf-8'))
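
The alphabet report in the generator splits the corpus' code points into printable symbols and combining marks via make_printable. A standalone approximation using unicodedata directly (a sketch, not kraken's helper) looks like this:

import unicodedata

def report_alphabet(lines):
    # collect every code point in the corpus, then separate base symbols
    # from combining marks, roughly mirroring the report printed above
    alphabet = set().union(*map(set, lines))
    chars, combining = [], []
    for char in sorted(alphabet):
        if unicodedata.combining(char):
            combining.append('U+{:04X}'.format(ord(char)))
        else:
            chars.append(char)
    print('Σ (len: {})'.format(len(alphabet)))
    print('Symbols: {}'.format(''.join(chars)))
    if combining:
        print('Combining Characters: {}'.format(', '.join(combining)))

# NFD decomposition makes the combining diaeresis show up separately
report_alphabet([unicodedata.normalize('NFD', 'Grüße')])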
Example n. 14
def recognizer(model, pad, no_segmentation, bidi_reordering, script_ignore,
               base_image, input, output, lines) -> None:

    import json
    import tempfile

    from kraken import rpred

    try:
        im = Image.open(base_image)
    except IOError as e:
        raise click.BadParameter(str(e))

    ctx = click.get_current_context()

    # The input may either be the output of the segmenter, in which case it
    # is a JSON file, or an image file when the OCR subcommand is run on its
    # own. It might still come from some other subcommand, though.
    scripts = set()
    tmp_segmentation = False
    if not lines and base_image != input:
        lines = input
    if not lines:
        if no_segmentation:
            lines = tempfile.NamedTemporaryFile(mode='w', delete=False)
            logger.info(
                'Running in no_segmentation mode. Creating temporary segmentation {}.'
                .format(lines.name))
            json.dump(
                {
                    'script_detection': False,
                    'text_direction': 'horizontal-lr',
                    'boxes': [(0, 0) + im.size]
                }, lines)
            lines.close()
            # remember that this temporary file was created here so it can
            # be removed once recognition is done
            tmp_segmentation = True
            lines = lines.name
        else:
            raise click.UsageError(
                'No line segmentation given. Add one with `-l` or run `segment` first.'
            )
    elif no_segmentation:
        logger.warning(
            'no_segmentation mode enabled but segmentation defined. Ignoring --no-segmentation option.'
        )

    with open_file(lines, 'r') as fp:
        try:
            fp = cast(IO[Any], fp)
            bounds = json.load(fp)
        except ValueError as e:
            raise click.UsageError('{} invalid segmentation: {}'.format(
                lines, str(e)))
        # script detection
        if bounds['script_detection']:
            for l in bounds['boxes']:
                for t in l:
                    scripts.add(t[0])
            it = rpred.mm_rpred(model,
                                im,
                                bounds,
                                pad,
                                bidi_reordering=bidi_reordering,
                                script_ignore=script_ignore)
        else:
            it = rpred.rpred(model['default'],
                             im,
                             bounds,
                             pad,
                             bidi_reordering=bidi_reordering)

    if tmp_segmentation:
        logger.debug('Removing temporary segmentation file.')
        os.unlink(lines)

    preds = []

    with log.progressbar(it, label='Processing',
                         length=len(bounds['boxes'])) as bar:
        for pred in bar:
            preds.append(pred)

    with open_file(output, 'w', encoding='utf-8') as fp:
        fp = cast(IO[Any], fp)
        message('Writing recognition results for {}\t'.format(base_image),
                nl=False)
        logger.info('Serializing as {} into {}'.format(ctx.meta['mode'],
                                                       output))
        if ctx.meta['mode'] != 'text':
            from kraken import serialization
            fp.write(
                serialization.serialize(preds, base_image,
                                        Image.open(base_image).size,
                                        ctx.meta['text_direction'], scripts,
                                        ctx.meta['mode']))
        else:
            fp.write('\n'.join(s.prediction for s in preds))
        message('\u2713', fg='green')
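
The segmentation consumed above is plain JSON in the same shape the no_segmentation branch writes out: a script_detection flag, a text direction, and one (x0, y0, x1, y1) box per line. A hand-written file for a two-line page might look like the following (coordinates are invented for illustration):

import json

# illustrative two-line segmentation in the format the recognizer reads;
# box coordinates are made up
segmentation = {
    'script_detection': False,
    'text_direction': 'horizontal-lr',
    'boxes': [
        [0, 0, 1200, 80],
        [0, 90, 1200, 170],
    ],
}
with open('lines.json', 'w') as fp:
    json.dump(segmentation, fp)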
Example n. 15
def train(ctx, pad, output, spec, append, load, savefreq, report, quit, epochs,
          lag, min_delta, device, optimizer, lrate, momentum, weight_decay,
          schedule, partition, normalization, codec, resize, reorder,
          training_files, evaluation_files, preload, threads, ground_truth):
    """
    Trains a model from image-text pairs.
    """
    if not load and append:
        raise click.BadOptionUsage(
            'append', 'append option requires loading an existing model')

    if resize != 'fail' and not load:
        raise click.BadOptionUsage(
            'resize', 'resize option requires loading an existing model')

    import re
    import torch
    import shutil
    import numpy as np

    from torch.utils.data import DataLoader

    from kraken.lib import models, vgsl, train
    from kraken.lib.util import make_printable
    from kraken.lib.train import EarlyStopping, EpochStopping, TrainStopper, TrainScheduler, add_1cycle
    from kraken.lib.codec import PytorchCodec
    from kraken.lib.dataset import GroundTruthDataset, compute_error, generate_input_transforms

    logger.info('Building ground truth set from {} line images'.format(
        len(ground_truth) + len(training_files)))

    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise the output size is still unknown.
    nn = None
    if load:
        logger.info('Loading existing model from {} '.format(load))
        message('Loading model {}'.format(load), nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        message('\u2713', fg='green', nl=False)

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        if spec[0] != '[' or spec[-1] != ']':
            raise click.BadOptionUsage(
                'spec', 'VGSL spec {} not bracketed'.format(spec))
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage(
                'spec', 'Invalid input spec {}'.format(blocks[0]))
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        batch, channels, height, width = nn.input
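    # n.b.: the VGSL spec string is batch,height,width,channels (NHWC) while
    # the input tuple of a loaded torch model is NCHW, hence the two
    # different unpackings above.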
    try:
        transforms = generate_input_transforms(batch, height, width, channels,
                                               pad)
    except KrakenInputException as e:
        raise click.BadOptionUsage('spec', str(e))

    # disable automatic partition when given evaluation set explicitly
    if evaluation_files:
        partition = 1
    ground_truth = list(ground_truth)

    # merge training_files into ground_truth list
    if training_files:
        ground_truth.extend(training_files)

    if len(ground_truth) == 0:
        raise click.UsageError(
            'No training data was provided to the train command. Use `-t` or the `ground_truth` argument.'
        )

    np.random.shuffle(ground_truth)

    if len(ground_truth) > 2500 and preload is None:
        logger.info(
            'Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter'
        )
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    tr_im = ground_truth[:int(len(ground_truth) * partition)]
    if evaluation_files:
        logger.debug('Using {} lines from explicit eval set'.format(
            len(evaluation_files)))
        te_im = evaluation_files
    else:
        te_im = ground_truth[int(len(ground_truth) * partition):]
        logger.debug('Taking {} lines from training for evaluation'.format(
            len(te_im)))

    gt_set = GroundTruthDataset(normalization=normalization,
                                reorder=reorder,
                                im_transforms=transforms,
                                preload=preload)
    with log.progressbar(tr_im, label='Building training set') as bar:
        for im in bar:
            logger.debug('Adding line {} to training set'.format(im))
            try:
                gt_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    val_set = GroundTruthDataset(normalization=normalization,
                                 reorder=reorder,
                                 im_transforms=transforms,
                                 preload=preload)
    with log.progressbar(te_im, label='Building validation set') as bar:
        for im in bar:
            logger.debug('Adding line {} to validation set'.format(im))
            try:
                val_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    logger.info(
        'Training set {} lines, validation set {} lines, alphabet {} symbols'.
        format(len(gt_set._images), len(val_set._images),
               len(gt_set.alphabet)))
    alpha_diff = set(gt_set.alphabet).symmetric_difference(
        set(val_set.alphabet))
    if alpha_diff:
        logger.warning('alphabet mismatch {}'.format(alpha_diff))
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(),
                       key=lambda x: x[1],
                       reverse=True):
        char = make_printable(k)
        if char == k:
            char = '\t' + char
        logger.info(u'{}\t{}'.format(char, v))

    logger.debug('Encoding training set')

    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        logger.info('Appending {} to existing model {} after {}'.format(
            spec, nn.spec, append))
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info('Assembled model spec: {}'.format(nn.spec))
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException as e:
            message('Network codec not compatible with training set')
            alpha_diff = set(gt_set.alphabet).difference(set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error(
                    'Training data and model codec alphabets mismatch: {}'.
                    format(alpha_diff))
                ctx.exit(code=1)
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info(
                    'Resizing codec to include {} new code points'.format(
                        len(alpha_diff)))
                codec.c2l.update({
                    k: [v]
                    for v, k in enumerate(alpha_diff,
                                          start=codec.max_label() + 1)
                })
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info(
                    'Resizing last layer in network to {} outputs'.format(
                        codec.max_label() + 1))
                nn.resize_output(codec.max_label() + 1)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ', nl=False)
                logger.info(
                    'Resizing network or given codec to {} code sequences'.
                    format(len(gt_set.alphabet)))
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info(
                    'Deleting {} output classes from network ({} retained)'.
                    format(len(del_labels),
                           len(codec) - len(del_labels)))
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label() + 1, del_labels)
                message('\u2713', fg='green')
            else:
                raise click.BadOptionUsage(
                    'resize', 'Invalid resize value {}'.format(resize))
    else:
        gt_set.encode(codec)
        logger.info('Creating new model {} with {} outputs'.format(
            spec,
            gt_set.codec.max_label() + 1))
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    train_loader = DataLoader(gt_set,
                              batch_size=1,
                              shuffle=True,
                              pin_memory=True)

    # don't encode the validation set, as the alphabets may not match, causing encoding failures
    val_set.training_set = list(zip(val_set._images, val_set._gt))

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(
        optimizer, lrate, momentum))

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    nn.set_num_threads(threads)

    logger.debug('Moving model to device {}'.format(device))
    rec = models.TorchSeqRecognizer(nn, train=True, device=device)
    optim = getattr(torch.optim, optimizer)(nn.nn.parameters(), lr=0)

    tr_it = TrainScheduler(optim)
    if schedule == '1cycle':
        add_1cycle(tr_it, epochs * len(gt_set), lrate, momentum,
                   momentum - 0.10, weight_decay)
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1, (lrate, lrate), (momentum, momentum), weight_decay,
                        train.annealing_const)

    st_it = cast(TrainStopper, None)  # type: TrainStopper
    if quit == 'early':
        st_it = EarlyStopping(train_loader, min_delta, lag)
    elif quit == 'dumb':
        st_it = EpochStopping(train_loader, epochs)
    else:
        raise click.BadOptionUsage(
            'quit', 'Invalid training interruption scheme {}'.format(quit))

    for epoch, loader in enumerate(st_it):
        with log.progressbar(label='epoch {}/{}'.format(
                epoch, epochs - 1 if epochs > 0 else '∞'),
                             length=len(loader),
                             show_pos=True) as bar:
            acc_loss = torch.tensor(0.0).to(device, non_blocking=True)
            for trial, (input, target) in enumerate(loader):
                tr_it.step()
                input = input.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)
                input = input.requires_grad_()
                o = nn.nn(input)
                # height should be 1 by now
                if o.size(2) != 1:
                    raise KrakenInputException(
                        'Expected dimension 3 to be 1, actual {}'.format(
                            o.size(2)))
                o = o.squeeze(2)
                optim.zero_grad()
                # NCW -> WNC: the criterion (CTC for recognition models)
                # expects (seq_len, batch, classes) activations followed by
                # the activation and target lengths
                loss = nn.criterion(
                    o.permute(2, 0, 1),  # type: ignore
                    target,
                    (o.size(2), ),
                    (target.size(1), ))
                logger.info('trial {}'.format(trial))
                if not torch.isinf(loss):
                    loss.backward()
                    optim.step()
                else:
                    logger.debug('infinite loss in trial {}'.format(trial))
                bar.update(1)
        if not epoch % savefreq:
            logger.info('Saving to {}_{}'.format(output, epoch))
            try:
                nn.save_model('{}_{}.mlmodel'.format(output, epoch))
            except Exception as e:
                logger.error('Saving model failed: {}'.format(str(e)))
        if not epoch % report:
            logger.debug('Starting evaluation run')
            nn.eval()
            chars, error = compute_error(rec, list(val_set))
            nn.train()
            accuracy = (chars - error) / chars
            logger.info('Accuracy report ({}) {:0.4f} {} {}'.format(
                epoch, accuracy, chars, error))
            message('Accuracy report ({}) {:0.4f} {} {}'.format(
                epoch, accuracy, chars, error))
            st_it.update(accuracy)
    if quit == 'early':
        message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
                format(output, st_it.best_epoch, st_it.best_loss))
        logger.info(
            'Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
            format(output, st_it.best_epoch, st_it.best_loss))
        shutil.copy('{}_{}.mlmodel'.format(output, st_it.best_epoch),
                    '{}_best.mlmodel'.format(output))
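
The accuracy figure reported above comes from compute_error, which returns the total number of ground-truth characters together with the aggregate edit distance, so accuracy is simply (chars - error) / chars. A toy stand-in using a plain Levenshtein distance (a sketch, not kraken's implementation) reproduces the arithmetic:

def levenshtein(a, b):
    # classic dynamic-programming edit distance, kept to one row at a time
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                  # deletion
                           cur[-1] + 1,                  # insertion
                           prev[j - 1] + (ca != cb)))    # substitution
        prev = cur
    return prev[-1]

# made-up (ground truth, prediction) pairs
pairs = [('kraken', 'krakem'), ('ocr', 'ocr')]
chars = sum(len(gt) for gt, _ in pairs)
error = sum(levenshtein(gt, pred) for gt, pred in pairs)
print('Accuracy report {:0.4f} {} {}'.format((chars - error) / chars, chars, error))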