def train(ctx, lineheight, pad, hiddensize, output, load, savefreq, report,
          ntrain, lrate, momentum, partition, normalization, reorder,
          ground_truth):
    """
    Trains a model from image-text pairs.

    Builds a ground truth set from ``ground_truth`` line images, loads or
    initializes a CLSTM sequence recognizer, and runs ``ntrain`` training
    iterations, periodically saving checkpoints and reporting accuracy.

    Args:
        ctx: click context; ``ctx.meta['verbose']`` selects between timed
             log lines (>0) and a spinner (0).
        lineheight: input line height for a newly created model.
        pad: unused here; accepted for interface compatibility.
        hiddensize: number of hidden units for a newly created model.
        output: checkpoint path prefix; checkpoints are ``{output}_{trial}``.
        load: path of an existing model to continue training, or falsy to
              create a new one.
        savefreq: save a checkpoint every ``savefreq`` trials.
        report: run an accuracy report every ``report`` trials.
        ntrain: total number of training iterations.
        lrate, momentum: SGD hyperparameters passed to the recognizer.
        partition: training/test split passed to ``gt_set.repartition``.
        normalization, reorder: text normalization options for ground truth.
        ground_truth: iterable of line image paths.
    """
    st_time = time.time()
    if ctx.meta['verbose'] > 0:
        click.echo(
            u'[{:2.4f}] Building ground truth set from {} line images'.format(
                time.time() - st_time, len(ground_truth)))
    else:
        spin('Building ground truth set')
    gt_set = GroundTruthContainer()
    for line in ground_truth:
        gt_set.add(line, normalization=normalization, reorder=reorder)
        if ctx.meta['verbose'] > 2:
            click.echo(u'[{:2.4f}] Adding {}'.format(time.time() - st_time,
                                                     line))
        else:
            spin('Building ground truth set')
    gt_set.repartition(partition)
    if ctx.meta['verbose'] < 3:
        click.echo('')
    if ctx.meta['verbose'] > 0:
        click.echo(
            u'[{:2.4f}] Training set {} lines, test set {} lines, alphabet {} symbols'
            .format(time.time() - st_time, len(gt_set.training_set),
                    len(gt_set.test_set), len(gt_set.training_alphabet)))
    if ctx.meta['verbose'] > 1:
        # Dump the grapheme frequency table, most frequent first.
        click.echo(u'[{:2.4f}] grapheme\tcount'.format(time.time() - st_time))
        # items() instead of Py2-only iteritems(); lambda x (not the
        # Py2-only parenthesized form) so the code also parses on Python 3.
        for k, v in sorted(gt_set.training_alphabet.items(),
                           key=lambda x: x[1], reverse=True):
            # Combining marks and whitespace would render invisibly, so
            # print their Unicode names instead.
            if unicodedata.combining(k) or k.isspace():
                k = unicodedata.name(k)
            else:
                k = '\t' + k
            click.echo(u'[{:2.4f}] {}\t{}'.format(time.time() - st_time, k, v))
    if not ctx.meta['verbose']:
        # Terminate the spinner with a check mark and re-show the cursor.
        click.secho(u'\b\u2713', fg='green', nl=False)
        click.echo('\033[?25h\n', nl=False)
    if load:
        if ctx.meta['verbose'] > 0:
            click.echo(u'[{:2.4f}] Loading existing model from {} '.format(
                time.time() - st_time, load))
        else:
            spin('Loading model')
        rnn = models.ClstmSeqRecognizer(load)
        if not ctx.meta['verbose'] > 0:
            click.secho(u'\b\u2713', fg='green', nl=False)
            click.echo('\033[?25h\n', nl=False)
    else:
        if ctx.meta['verbose'] > 0:
            # BUGFIX: the message previously referenced an undefined name
            # `codec` (NameError). The number of outputs is the training
            # alphabet size, which is what init_model receives below.
            click.echo(
                u'[{:2.4f}] Creating new model with line height {}, {} hidden units, and {} outputs'
                .format(time.time() - st_time, lineheight, hiddensize,
                        len(gt_set.training_alphabet)))
        else:
            spin('Initializing model')
        rnn = models.ClstmSeqRecognizer.init_model(
            lineheight, hiddensize, gt_set.training_alphabet.keys())
        if not ctx.meta['verbose']:
            click.secho(u'\b\u2713', fg='green', nl=False)
            click.echo('\033[?25h\n', nl=False)
    if ctx.meta['verbose'] > 0:
        click.echo(
            u'[{:2.4f}] Setting learning rate ({}) and momentum ({}) '.format(
                time.time() - st_time, lrate, momentum))
    rnn.setLearningRate(lrate, momentum)
    for trial in xrange(ntrain):
        line, s = gt_set.sample()
        res = rnn.trainString(line, s)
        if ctx.meta['verbose'] > 2:
            click.echo(u'[{0:2.4f}] TRU: {1}\n[{0:2.4f}] OUT: {2}'.format(
                time.time() - st_time, s, res))
        else:
            spin('Training')
        if trial and not trial % savefreq:
            rnn.save_model('{}_{}'.format(output, trial))
            if ctx.meta['verbose'] < 3:
                click.echo('')
            if ctx.meta['verbose'] > 0:
                click.echo(u'[{:2.4f}] Saving to {}_{}'.format(
                    time.time() - st_time, output, trial))
        if trial and not trial % report:
            c, e = compute_error(rnn, gt_set.test_set)
            if ctx.meta['verbose'] < 3:
                click.echo('')
            # NOTE(review): if compute_error returns ints, (c - e) / c is
            # integer division under Python 2 — confirm it returns floats.
            click.echo(u'[{:2.4f}] Accuracy report ({}) {:0.4f} {} {}'.format(
                time.time() - st_time, trial, (c - e) / c, c, e))
def train(ctx, lineheight, pad, hiddensize, output, load, savefreq, report,
          ntrain, lrate, momentum, partition, normalization, reorder,
          ground_truth):
    """
    Trains a model from image-text pairs.

    Logger-based variant: requires an existing model (``load``); builds a
    ground truth set, continues training for ``ntrain`` iterations, saving
    checkpoints every ``savefreq`` trials and reporting accuracy every
    ``report`` trials.

    Args:
        ctx: click context, used only for ``ctx.exit``.
        lineheight, pad, hiddensize: accepted for interface compatibility
            with the model-creation path (not used when loading a model).
        output: checkpoint path prefix; checkpoints are ``{output}_{trial}``.
        load: path of an existing model; training from scratch is rejected.
        savefreq: save a checkpoint every ``savefreq`` trials.
        report: run an accuracy report every ``report`` trials.
        ntrain: total number of training iterations.
        lrate, momentum: SGD hyperparameters passed to the recognizer.
        partition: training/test split passed to ``gt_set.repartition``.
        normalization, reorder: text normalization options for ground truth.
        ground_truth: iterable of line image paths.
    """
    if load is None:
        # Typo fix: "net yet" -> "not yet".
        message(u'Training from scratch not yet supported.')
        ctx.exit(1)
    logger.info(u'Building ground truth set from {} line images'.format(
        len(ground_truth)))
    spin(u'Building ground truth set')
    gt_set = GroundTruthContainer()
    for line in ground_truth:
        gt_set.add(line, normalization=normalization, reorder=reorder)
        logger.debug(u'Adding {}'.format(line))
        spin(u'Building ground truth set')
    gt_set.repartition(partition)
    logger.info(
        u'Training set {} lines, test set {} lines, alphabet {} symbols'.
        format(len(gt_set.training_set), len(gt_set.test_set),
               len(gt_set.training_alphabet)))
    logger.debug(u'grapheme\tcount')
    for k, v in sorted(gt_set.training_alphabet.items(),
                       key=lambda x: x[1], reverse=True):
        # Combining marks and whitespace would render invisibly, so log
        # their Unicode names instead.
        if unicodedata.combining(k) or k.isspace():
            k = unicodedata.name(k)
        else:
            k = '\t' + k
        logger.debug(u'{}\t{}'.format(k, v))
    # Terminate the spinner with a check mark and re-show the cursor.
    message(u'\b\u2713', fg='green', nl=False)
    message('\033[?25h\n', nl=False)
    if load:
        logger.info(u'Loading existing model from {} '.format(load))
        spin('Loading model')
        rnn = models.ClstmSeqRecognizer(load)
        message(u'\b\u2713', fg='green', nl=False)
        message('\033[?25h\n', nl=False)
    else:
        # Unreachable when load is None (guarded above); kept as a
        # defensive exit for other falsy values (e.g. empty string).
        ctx.exit(1)
    logger.info(u'Setting learning rate ({}) and momentum ({}) '.format(
        lrate, momentum))
    rnn.setLearningRate(lrate, momentum)
    for trial in xrange(ntrain):
        line, s = gt_set.sample()
        res = rnn.trainString(line, s)
        # BUGFIX: the previous format string used indices {1}/{2} plus a
        # leftover {0:2.4f} timestamp placeholder from the click-based
        # version, but only two arguments were supplied -> IndexError
        # whenever debug logging was enabled.
        logger.debug(u'TRU: {}'.format(s))
        logger.debug(u'OUT: {}'.format(res))
        spin('Training')
        if trial and not trial % savefreq:
            rnn.save_model('{}_{}'.format(output, trial))
            logger.info(u'Saving to {}_{}'.format(output, trial))
        if trial and not trial % report:
            c, e = compute_error(rnn, gt_set.test_set)
            # NOTE(review): if compute_error returns ints, (c - e) / c is
            # integer division under Python 2 — confirm it returns floats.
            logger.info(u'Accuracy report ({}) {:0.4f} {} {}'.format(
                trial, (c - e) / c, c, e))