Ejemplo n.º 1
0
def train(ctx, lineheight, pad, hiddensize, output, load, savefreq, report,
          ntrain, lrate, momentum, partition, normalization, reorder,
          ground_truth):
    """
    Trains a model from image-text pairs.

    Builds a ground truth set from ``ground_truth``, loads an existing model
    or initializes a new one, then runs ``ntrain`` training iterations while
    periodically saving checkpoints and printing accuracy reports. Progress
    output is controlled by ``ctx.meta['verbose']``: timestamped messages are
    printed at higher levels, a spinner is shown otherwise.

    Args:
        ctx: click context; ``ctx.meta['verbose']`` is read as the verbosity
            level.
        lineheight: line height handed to ``init_model`` for a new model.
        pad: not used by this function.
        hiddensize: number of hidden units handed to ``init_model``.
        output: checkpoint name prefix; models are saved as
            ``{output}_{trial}``.
        load: model to continue training from; when falsy a new model is
            initialized from the training alphabet.
        savefreq: save a checkpoint every ``savefreq`` iterations; ``0``
            disables checkpointing.
        report: print an accuracy report every ``report`` iterations; ``0``
            disables reporting.
        ntrain: number of training iterations.
        lrate: learning rate passed to ``rnn.setLearningRate``.
        momentum: momentum passed to ``rnn.setLearningRate``.
        partition: value passed to ``gt_set.repartition``.
        normalization: forwarded to ``gt_set.add``.
        reorder: forwarded to ``gt_set.add``.
        ground_truth: sized iterable of ground truth lines.
    """
    st_time = time.time()
    verbose = ctx.meta['verbose']

    def elapsed():
        # Seconds since the command started; prefixes every verbose message.
        return time.time() - st_time

    def spinner_done():
        # Overwrite the spinner glyph with a green check mark, then emit the
        # show-cursor escape sequence and finish the line.
        click.secho(u'\b\u2713', fg='green', nl=False)
        click.echo('\033[?25h\n', nl=False)

    if verbose > 0:
        click.echo(
            u'[{:2.4f}] Building ground truth set from {} line images'.format(
                elapsed(), len(ground_truth)))
    else:
        spin('Building ground truth set')

    gt_set = GroundTruthContainer()

    for line in ground_truth:
        gt_set.add(line, normalization=normalization, reorder=reorder)
        if verbose > 2:
            click.echo(u'[{:2.4f}] Adding {}'.format(elapsed(), line))
        else:
            spin('Building ground truth set')
    gt_set.repartition(partition)
    if verbose < 3:
        # Terminate the spinner line before regular output continues.
        click.echo('')
    if verbose > 0:
        click.echo(
            u'[{:2.4f}] Training set {} lines, test set {} lines, alphabet {} symbols'
            .format(elapsed(), len(gt_set.training_set),
                    len(gt_set.test_set), len(gt_set.training_alphabet)))
    if verbose > 1:
        click.echo(u'[{:2.4f}] grapheme\tcount'.format(elapsed()))
        # Most frequent alphabet entries first.
        for k, v in sorted(gt_set.training_alphabet.items(),
                           key=lambda x: x[1],
                           reverse=True):
            if unicodedata.combining(k) or k.isspace():
                # Invisible characters are shown by name. Unnamed ones (e.g.
                # tab, newline) fall back to their repr instead of raising
                # ValueError.
                k = unicodedata.name(k, repr(k))
            else:
                k = '\t' + k
            click.echo(u'[{:2.4f}] {}\t{}'.format(elapsed(), k, v))

    if not verbose:
        spinner_done()

    if load:
        if verbose > 0:
            click.echo(u'[{:2.4f}] Loading existing model from {} '.format(
                elapsed(), load))
        else:
            spin('Loading model')

        rnn = models.ClstmSeqRecognizer(load)

        if not verbose > 0:
            spinner_done()

    else:
        if verbose > 0:
            # The outputs of a new model are the training alphabet entries
            # (see init_model call below), so report their count.
            click.echo(
                u'[{:2.4f}] Creating new model with line height {}, {} hidden units, and {} outputs'
                .format(elapsed(), lineheight, hiddensize,
                        len(gt_set.training_alphabet)))
        else:
            spin('Initializing model')
        rnn = models.ClstmSeqRecognizer.init_model(
            lineheight, hiddensize, gt_set.training_alphabet.keys())
        if not verbose:
            spinner_done()

    if verbose > 0:
        click.echo(
            u'[{:2.4f}] Setting learning rate ({}) and momentum ({}) '.format(
                elapsed(), lrate, momentum))
    rnn.setLearningRate(lrate, momentum)

    for trial in xrange(ntrain):
        line, s = gt_set.sample()
        res = rnn.trainString(line, s)
        if verbose > 2:
            click.echo(u'[{0:2.4f}] TRU: {1}\n[{0:2.4f}] OUT: {2}'.format(
                elapsed(), s, res))
        else:
            spin('Training')

        # Iteration 0 is skipped on purpose; a frequency of 0 disables the
        # action instead of raising ZeroDivisionError.
        if savefreq and trial and not trial % savefreq:
            rnn.save_model('{}_{}'.format(output, trial))
            if verbose < 3:
                click.echo('')
            if verbose > 0:
                click.echo(u'[{:2.4f}] Saving to {}_{}'.format(
                    elapsed(), output, trial))

        if report and trial and not trial % report:
            c, e = compute_error(rnn, gt_set.test_set)
            if verbose < 3:
                click.echo('')
            # Force float division (integer operands would truncate under
            # Python 2) and report 0.0 when the first returned value is 0.
            accuracy = float(c - e) / c if c else 0.0
            click.echo(u'[{:2.4f}] Accuracy report ({}) {:0.4f} {} {}'.format(
                elapsed(), trial, accuracy, c, e))
Ejemplo n.º 2
0
def train(ctx, lineheight, pad, hiddensize, output, load, savefreq, report,
          ntrain, lrate, momentum, partition, normalization, reorder,
          ground_truth):
    """
    Trains a model from image-text pairs.

    Continues training of an existing model: builds a ground truth set from
    ``ground_truth``, loads the model given by ``load`` and runs ``ntrain``
    training iterations while periodically saving checkpoints and logging
    accuracy reports. Training from scratch is not supported; the command
    exits with status 1 when no model is given.

    Args:
        ctx: click context, used to exit the command.
        lineheight: not used by this function.
        pad: not used by this function.
        hiddensize: not used by this function.
        output: checkpoint name prefix; models are saved as
            ``{output}_{trial}``.
        load: model to continue training from (required).
        savefreq: save a checkpoint every ``savefreq`` iterations; ``0``
            disables checkpointing.
        report: log an accuracy report every ``report`` iterations; ``0``
            disables reporting.
        ntrain: number of training iterations.
        lrate: learning rate passed to ``rnn.setLearningRate``.
        momentum: momentum passed to ``rnn.setLearningRate``.
        partition: value passed to ``gt_set.repartition``.
        normalization: forwarded to ``gt_set.add``.
        reorder: forwarded to ``gt_set.add``.
        ground_truth: sized iterable of ground truth lines.
    """
    # Refuse to run without a model before doing any expensive work. This
    # also covers an empty ``load`` value, which previously exited only after
    # the ground truth set had been built and without any message.
    if not load:
        message(u'Training from scratch not yet supported.')
        ctx.exit(1)

    logger.info(u'Building ground truth set from {} line images'.format(
        len(ground_truth)))
    spin(u'Building ground truth set')

    gt_set = GroundTruthContainer()

    for line in ground_truth:
        gt_set.add(line, normalization=normalization, reorder=reorder)
        logger.debug(u'Adding {}'.format(line))
        spin(u'Building ground truth set')
    gt_set.repartition(partition)

    logger.info(
        u'Training set {} lines, test set {} lines, alphabet {} symbols'.
        format(len(gt_set.training_set), len(gt_set.test_set),
               len(gt_set.training_alphabet)))
    logger.debug(u'grapheme\tcount')
    # Most frequent alphabet entries first.
    for k, v in sorted(gt_set.training_alphabet.items(),
                       key=lambda x: x[1],
                       reverse=True):
        if unicodedata.combining(k) or k.isspace():
            # Invisible characters are logged by name. Unnamed ones (e.g. tab,
            # newline) fall back to their repr instead of raising ValueError.
            k = unicodedata.name(k, repr(k))
        else:
            k = '\t' + k
        logger.debug(u'{}\t{}'.format(k, v))

    # Overwrite the spinner glyph with a green check mark, then emit the
    # show-cursor escape sequence and finish the line.
    message(u'\b\u2713', fg='green', nl=False)
    message('\033[?25h\n', nl=False)

    logger.info(u'Loading existing model from {} '.format(load))
    spin('Loading model')

    rnn = models.ClstmSeqRecognizer(load)

    message(u'\b\u2713', fg='green', nl=False)
    message('\033[?25h\n', nl=False)

    logger.info(u'Setting learning rate ({}) and momentum ({}) '.format(
        lrate, momentum))
    rnn.setLearningRate(lrate, momentum)

    for trial in xrange(ntrain):
        line, s = gt_set.sample()
        res = rnn.trainString(line, s)
        # The former template referenced a third argument and applied a float
        # format to the text, so it raised on every iteration.
        logger.debug(u'TRU: {}\nOUT: {}'.format(s, res))
        spin('Training')

        # Iteration 0 is skipped on purpose; a frequency of 0 disables the
        # action instead of raising ZeroDivisionError.
        if savefreq and trial and not trial % savefreq:
            rnn.save_model('{}_{}'.format(output, trial))
            logger.info(u'Saving to {}_{}'.format(output, trial))

        if report and trial and not trial % report:
            c, e = compute_error(rnn, gt_set.test_set)
            # Force float division (integer operands would truncate under
            # Python 2) and report 0.0 when the first returned value is 0.
            accuracy = float(c - e) / c if c else 0.0
            logger.info(u'Accuracy report ({}) {:0.4f} {} {}'.format(
                trial, accuracy, c, e))