예제 #1
0
    def run(self,
            event_callback=lambda *args, **kwargs: None,
            iteration_callback=lambda *args, **kwargs: None):
        """
        Runs training rounds for as long as `self.stopper.trigger()` is truthy.

        Every round feeds at most `self.event_it` samples drawn from
        `self.train_set` through the network, then evaluates on
        `self.val_set`, reports the accuracy to the stopper, records it in the
        model metadata and writes a checkpoint.

        Args:
            event_callback: Invoked once per round after evaluation with the
                            keyword arguments `epoch`, `accuracy`, `chars`
                            and `error`.
            iteration_callback: Invoked without arguments after each
                                training sample.

        Raises:
            KrakenInputException: if dimension 2 of the network output does
                                  not have size 1.
        """
        logger.debug('Starting up training...')

        while self.stopper.trigger():
            # zip() against the range caps a round at `event_it` samples
            for _, (line, labels) in zip(range(self.event_it), self.train_set):
                if self.lr_scheduler:
                    self.lr_scheduler.step()
                line = line.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)
                line = line.requires_grad_()
                output = self.model.nn(line)
                # height should be 1 by now
                height = output.size(2)
                if height != 1:
                    raise KrakenInputException(
                        'Expected dimension 3 to be 1, actual {}'.format(height))
                output = output.squeeze(2)
                self.optimizer.zero_grad()
                # NCW -> WNC
                seq_lens = (output.size(2), )
                label_lens = (labels.size(1), )
                loss = self.model.criterion(output.permute(2, 0, 1),  # type: ignore
                                            labels,
                                            seq_lens,
                                            label_lens)
                if torch.isinf(loss):
                    # skip the parameter update for samples with infinite loss
                    logger.debug('infinite loss in trial')
                else:
                    loss.backward()
                    self.optimizer.step()
                iteration_callback()

            self.iterations += self.event_it

            logger.debug('Starting evaluation run')
            self.model.eval()
            chars, error = compute_error(self.rec, list(self.val_set))
            self.model.train()
            accuracy = (chars - error) / chars
            report = 'Accuracy report ({}) {:0.4f} {} {}'.format(
                self.stopper.epoch, accuracy, chars, error)
            logger.info(report)

            self.stopper.update(accuracy)
            self.model.user_metadata['accuracy'].append(
                (self.iterations, accuracy))

            logger.info('Saving to {}_{}'.format(self.filename_prefix,
                                                 self.stopper.epoch))
            event_callback(epoch=self.stopper.epoch,
                           accuracy=accuracy,
                           chars=chars,
                           error=error)
            # a failed checkpoint write is logged but does not abort training
            try:
                self.model.user_metadata['completed_epochs'] = self.stopper.epoch
                checkpoint = '{}_{}.mlmodel'.format(self.filename_prefix,
                                                    self.stopper.epoch)
                self.model.save_model(checkpoint)
            except Exception as e:
                logger.error('Saving model failed: {}'.format(str(e)))
예제 #2
0
파일: train.py 프로젝트: eighttails/kraken
def recognition_evaluator_fn(model, val_loader, device):
    """
    Evaluates a recognition model on a validation loader.

    Wraps `model` in a `models.TorchSeqRecognizer` on `device`, runs
    `compute_error` over `val_loader` and derives the accuracy as
    `(chars - error) / chars`.

    Returns:
        A dict with the keys `val_metric` and `accuracy` (both the accuracy
        as a Python scalar obtained through `.item()`), `chars` and `error`
        (the values returned by `compute_error`).
    """
    recognizer = models.TorchSeqRecognizer(model, device=device)
    chars, error = compute_error(recognizer, val_loader)
    # presumably evaluation leaves the model in eval mode; switch back
    model.train()
    # `.float()`/`.item()` imply that chars and error are tensors here
    correct = (chars - error).float()
    accuracy = (correct / chars).item()
    return dict(val_metric=accuracy,
                accuracy=accuracy,
                chars=chars,
                error=error)
예제 #3
0
파일: train.py 프로젝트: millawell/kraken
def recognition_evaluator_fn(model, val_set, device):
    """
    Evaluates a recognition model on a validation set.

    Wraps `model` in a `models.TorchSeqRecognizer` on `device`, runs
    `compute_error` over the materialized `val_set` and derives the accuracy
    as `(chars - error) / chars`.

    Returns:
        A dict with the keys `val_metric`, `accuracy`, `chars` and `error`.
    """
    recognizer = models.TorchSeqRecognizer(model, device=device)
    samples = list(val_set)
    chars, error = compute_error(recognizer, samples)
    # presumably evaluation leaves the model in eval mode; switch back
    model.train()
    accuracy = (chars - error) / chars
    return dict(val_metric=accuracy,
                accuracy=accuracy,
                chars=chars,
                error=error)
예제 #4
0
파일: ketos.py 프로젝트: D-K-E/kraken
def train(ctx, pad, output, spec, append, load, savefreq, report, quit, epochs,
          lag, min_delta, device, optimizer, lrate, momentum, weight_decay,
          schedule, partition, normalization, codec, resize, reorder,
          training_files, evaluation_files, preload, threads, ground_truth):
    """
    Trains a model from image-text pairs.
    """
    # NOTE(review): the docstring above is presumably rendered by click as the
    # command's help text, so it is deliberately left untouched and the
    # detailed documentation lives in comments.
    #
    # Steps performed below:
    #   1. validate option combinations and optionally load an existing model
    #      from `load`
    #   2. derive the input dimensions (from `spec` or the loaded model) and
    #      build the image transforms
    #   3. split the ground truth into training and validation sets
    #      (`partition` is the training fraction unless `evaluation_files`
    #      is given)
    #   4. encode the training set and create / append to / resize the
    #      network depending on `append`, `load` and `resize`
    #      ('fail', 'add' or 'both')
    #   5. run the training loop, saving a model whenever
    #      `epoch % savefreq == 0` and evaluating whenever
    #      `epoch % report == 0`
    #   6. with `quit == 'early'`, copy the best model to
    #      `<output>_best.mlmodel`
    #
    # Raises click.BadOptionUsage / click.UsageError on invalid options and
    # KrakenInputException when the network output height is not 1.
    if not load and append:
        raise click.BadOptionUsage(
            'append', 'append option requires loading an existing model')

    if resize != 'fail' and not load:
        raise click.BadOptionUsage(
            'resize', 'resize option requires loading an existing model')

    import re
    import torch
    import shutil
    import numpy as np

    from torch.utils.data import DataLoader

    from kraken.lib import models, vgsl, train
    from kraken.lib.util import make_printable
    from kraken.lib.train import EarlyStopping, EpochStopping, TrainStopper, TrainScheduler, add_1cycle
    from kraken.lib.codec import PytorchCodec
    from kraken.lib.dataset import GroundTruthDataset, compute_error, generate_input_transforms

    logger.info('Building ground truth set from {} line images'.format(
        len(ground_truth) + len(training_files)))

    # load model if given. if a new model has to be created we need to do that
    # after data set initialization, otherwise the output size is still unknown.
    nn = None
    if load:
        logger.info('Loading existing model from {} '.format(load))
        message('Loading model {}'.format(load), nl=False)
        nn = vgsl.TorchVGSLModel.load_model(load)
        message('\u2713', fg='green', nl=False)

    # preparse input sizes from vgsl string to seed ground truth data set
    # sizes and dimension ordering.
    if not nn:
        spec = spec.strip()
        if spec[0] != '[' or spec[-1] != ']':
            raise click.BadOptionUsage(
                'spec', 'VGSL spec {} not bracketed'.format(spec))
        blocks = spec[1:-1].split(' ')
        m = re.match(r'(\d+),(\d+),(\d+),(\d+)', blocks[0])
        if not m:
            raise click.BadOptionUsage(
                'spec', 'Invalid input spec {}'.format(blocks[0]))
        # the spec string orders the input as batch,height,width,channels
        batch, height, width, channels = [int(x) for x in m.groups()]
    else:
        # a loaded model reports its input as batch,channels,height,width
        batch, channels, height, width = nn.input
    try:
        transforms = generate_input_transforms(batch, height, width, channels,
                                               pad)
    except KrakenInputException as e:
        raise click.BadOptionUsage('spec', str(e))

    # disable automatic partition when given evaluation set explicitly
    if evaluation_files:
        partition = 1
    ground_truth = list(ground_truth)

    # merge training_files into ground_truth list
    if training_files:
        ground_truth.extend(training_files)

    if not ground_truth:
        raise click.UsageError(
            'No training data was provided to the train command. Use `-t` or the `ground_truth` argument.'
        )

    # shuffle so that the partition split below is random
    np.random.shuffle(ground_truth)

    if len(ground_truth) > 2500 and not preload:
        logger.info(
            'Disabling preloading for large (>2500) training data set. Enable by setting --preload parameter'
        )
        preload = False
    # implicit preloading enabled for small data sets
    if preload is None:
        preload = True

    tr_im = ground_truth[:int(len(ground_truth) * partition)]
    if evaluation_files:
        logger.debug('Using {} lines from explicit eval set'.format(
            len(evaluation_files)))
        te_im = evaluation_files
    else:
        te_im = ground_truth[int(len(ground_truth) * partition):]
        logger.debug('Taking {} lines from training for evaluation'.format(
            len(te_im)))

    gt_set = GroundTruthDataset(normalization=normalization,
                                reorder=reorder,
                                im_transforms=transforms,
                                preload=preload)
    with log.progressbar(tr_im, label='Building training set') as bar:
        for im in bar:
            logger.debug('Adding line {} to training set'.format(im))
            # unreadable or invalid lines are skipped with a warning
            try:
                gt_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    val_set = GroundTruthDataset(normalization=normalization,
                                 reorder=reorder,
                                 im_transforms=transforms,
                                 preload=preload)
    with log.progressbar(te_im, label='Building validation set') as bar:
        for im in bar:
            logger.debug('Adding line {} to validation set'.format(im))
            try:
                val_set.add(im)
            except FileNotFoundError as e:
                logger.warning('{}: {}. Skipping.'.format(
                    e.strerror, e.filename))
            except KrakenInputException as e:
                logger.warning(str(e))

    logger.info(
        'Training set {} lines, validation set {} lines, alphabet {} symbols'.
        format(len(gt_set._images), len(val_set._images),
               len(gt_set.alphabet)))
    alpha_diff = set(gt_set.alphabet).symmetric_difference(
        set(val_set.alphabet))
    if alpha_diff:
        # Logger.warn is a deprecated alias; use Logger.warning
        logger.warning('alphabet mismatch {}'.format(alpha_diff))
    logger.info('grapheme\tcount')
    for k, v in sorted(gt_set.alphabet.items(),
                       key=lambda x: x[1],
                       reverse=True):
        char = make_printable(k)
        if char == k:
            char = '\t' + char
        logger.info(u'{}\t{}'.format(char, v))

    logger.debug('Encoding training set')

    # use model codec when given
    if append:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)
        gt_set.encode(codec)
        message('Slicing and dicing model ', nl=False)
        # now we can create a new model
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        logger.info('Appending {} to existing model {} after {}'.format(
            spec, nn.spec, append))
        nn.append(append, spec)
        nn.add_codec(gt_set.codec)
        message('\u2713', fg='green')
        logger.info('Assembled model spec: {}'.format(nn.spec))
    elif load:
        # is already loaded
        nn = cast(vgsl.TorchVGSLModel, nn)

        # prefer explicitly given codec over network codec if mode is 'both'
        codec = codec if (codec and resize == 'both') else nn.codec

        try:
            gt_set.encode(codec)
        except KrakenEncodeException as e:
            message('Network codec not compatible with training set')
            alpha_diff = set(gt_set.alphabet).difference(set(codec.c2l.keys()))
            if resize == 'fail':
                logger.error(
                    'Training data and model codec alphabets mismatch: {}'.
                    format(alpha_diff))
                ctx.exit(code=1)
            elif resize == 'add':
                message('Adding missing labels to network ', nl=False)
                logger.info(
                    'Resizing codec to include {} new code points'.format(
                        len(alpha_diff)))
                # assign fresh labels after the current maximum to every
                # code point missing from the network codec
                codec.c2l.update({
                    k: [v]
                    for v, k in enumerate(alpha_diff,
                                          start=codec.max_label() + 1)
                })
                nn.add_codec(PytorchCodec(codec.c2l))
                logger.info(
                    'Resizing last layer in network to {} outputs'.format(
                        codec.max_label() + 1))
                nn.resize_output(codec.max_label() + 1)
                # the encode attempt above failed, so the training set has to
                # be encoded again with the extended codec (the 'both' branch
                # below re-encodes in the same way)
                gt_set.encode(nn.codec)
                message('\u2713', fg='green')
            elif resize == 'both':
                message('Fitting network exactly to training set ', nl=False)
                logger.info(
                    'Resizing network or given codec to {} code sequences'.
                    format(len(gt_set.alphabet)))
                # encode(None) makes the data set derive its own codec, which
                # is then merged with the network/explicit codec
                gt_set.encode(None)
                ncodec, del_labels = codec.merge(gt_set.codec)
                logger.info(
                    'Deleting {} output classes from network ({} retained)'.
                    format(len(del_labels),
                           len(codec) - len(del_labels)))
                gt_set.encode(ncodec)
                nn.resize_output(ncodec.max_label() + 1, del_labels)
                message('\u2713', fg='green')
            else:
                raise click.BadOptionUsage(
                    'resize', 'Invalid resize value {}'.format(resize))
    else:
        gt_set.encode(codec)
        logger.info('Creating new model {} with {} outputs'.format(
            spec,
            gt_set.codec.max_label() + 1))
        # append the output layer sized to the codec to the user's spec
        spec = '[{} O1c{}]'.format(spec[1:-1], gt_set.codec.max_label() + 1)
        nn = vgsl.TorchVGSLModel(spec)
        # initialize weights
        message('Initializing model ', nl=False)
        nn.init_weights()
        nn.add_codec(gt_set.codec)
        # initialize codec
        message('\u2713', fg='green')

    train_loader = DataLoader(gt_set,
                              batch_size=1,
                              shuffle=True,
                              pin_memory=True)

    # don't encode validation set as the alphabets may not match causing encoding failures
    val_set.training_set = list(zip(val_set._images, val_set._gt))

    logger.debug('Constructing {} optimizer (lr: {}, momentum: {})'.format(
        optimizer, lrate, momentum))

    # set mode to training
    nn.train()

    # set number of OpenMP threads
    logger.debug('Set OpenMP threads to {}'.format(threads))
    nn.set_num_threads(threads)

    logger.debug('Moving model to device {}'.format(device))
    rec = models.TorchSeqRecognizer(nn, train=True, device=device)
    # the optimizer starts with lr=0; the scheduler built below is given the
    # actual learning rate (presumably it sets it on each step -- confirm)
    optim = getattr(torch.optim, optimizer)(nn.nn.parameters(), lr=0)

    tr_it = TrainScheduler(optim)
    if schedule == '1cycle':
        add_1cycle(tr_it, epochs * len(gt_set), lrate, momentum,
                   momentum - 0.10, weight_decay)
    else:
        # constant learning rate scheduler
        tr_it.add_phase(1, (lrate, lrate), (momentum, momentum), weight_decay,
                        train.annealing_const)

    st_it = cast(TrainStopper, None)  # type: TrainStopper
    if quit == 'early':
        st_it = EarlyStopping(train_loader, min_delta, lag)
    elif quit == 'dumb':
        st_it = EpochStopping(train_loader, epochs)
    else:
        raise click.BadOptionUsage(
            'quit', 'Invalid training interruption scheme {}'.format(quit))

    # iterating the stopper yields one data loader per epoch
    for epoch, loader in enumerate(st_it):
        with log.progressbar(label='epoch {}/{}'.format(
                epoch, epochs - 1 if epochs > 0 else '∞'),
                             length=len(loader),
                             show_pos=True) as bar:
            for trial, (input, target) in enumerate(loader):
                tr_it.step()
                input = input.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)
                input = input.requires_grad_()
                o = nn.nn(input)
                # height should be 1 by now
                if o.size(2) != 1:
                    raise KrakenInputException(
                        'Expected dimension 3 to be 1, actual {}'.format(
                            o.size(2)))
                o = o.squeeze(2)
                optim.zero_grad()
                # NCW -> WNC
                loss = nn.criterion(
                    o.permute(2, 0, 1),  # type: ignore
                    target,
                    (o.size(2), ),
                    (target.size(1), ))
                logger.info('trial {}'.format(trial))
                # skip the parameter update for samples with infinite loss
                if not torch.isinf(loss):
                    loss.backward()
                    optim.step()
                else:
                    logger.debug('infinite loss in trial {}'.format(trial))
                bar.update(1)
        if not epoch % savefreq:
            logger.info('Saving to {}_{}'.format(output, epoch))
            # a failed save is logged but does not abort training
            try:
                nn.save_model('{}_{}.mlmodel'.format(output, epoch))
            except Exception as e:
                logger.error('Saving model failed: {}'.format(str(e)))
        if not epoch % report:
            logger.debug('Starting evaluation run')
            nn.eval()
            chars, error = compute_error(rec, list(val_set))
            nn.train()
            accuracy = (chars - error) / chars
            logger.info('Accuracy report ({}) {:0.4f} {} {}'.format(
                epoch, accuracy, chars, error))
            message('Accuracy report ({}) {:0.4f} {} {}'.format(
                epoch, accuracy, chars, error))
            st_it.update(accuracy)
    if quit == 'early':
        message('Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
                format(output, st_it.best_epoch, st_it.best_loss))
        logger.info(
            'Moving best model {0}_{1}.mlmodel ({2}) to {0}_best.mlmodel'.
            format(output, st_it.best_epoch, st_it.best_loss))
        # NOTE(review): this relies on a model having been saved for
        # `best_epoch`, i.e. on `savefreq` and `report` lining up -- confirm
        shutil.copy('{}_{}.mlmodel'.format(output, st_it.best_epoch),
                    '{}_best.mlmodel'.format(output))