Example #1
0
def main():
    """Train the OCR model with an ``mx.mod.BucketingModule``.

    Reads command-line arguments via ``parse_args`` and hyperparameters via
    ``Hyperparams``, optionally resumes from a checkpoint, builds the
    ``'train'`` and ``'val'`` ``OCRIter`` iterators and fits the module with
    SGD, checkpointing through ``mx.callback.do_checkpoint``.  A
    ``KeyboardInterrupt`` raised during setup or training is caught and
    reported instead of propagating.

    ``args.resume``, when set, has the form ``"<checkpoint prefix>,<epoch>"``.

    Raises:
        ValueError: if ``args.loss`` is neither ``'ctc'`` nor ``'warpctc'``,
            or if ``args.resume`` is not of the form described above.
    """
    args = parse_args()
    if args.loss not in ('ctc', 'warpctc'):
        raise ValueError("Invalid loss '{}' (must be 'ctc' or 'warpctc')".format(args.loss))
    hp = Hyperparams()

    try:
        if args.resume:
            # Split on the last comma only, so a checkpoint prefix that
            # itself contains commas still parses.
            model_path, epoch = args.resume.rsplit(",", 1)
            _, arg_params, aux_params = mx.model.load_checkpoint(model_path, int(epoch))
        else:
            arg_params, aux_params = None, None

        # One context per requested GPU, otherwise one per requested CPU.
        if args.gpu:
            contexts = [mx.context.gpu(i) for i in range(args.gpu)]
        else:
            contexts = [mx.context.cpu(i) for i in range(args.cpu)]

        # The same init_states object is handed to both iterators.
        init_states = lstm.init_states(hp.batch_size, hp.num_lstm_layer, hp.num_hidden)

        data_train = OCRIter(hp.batch_size, init_states, hp.data_path, name='train')
        data_val = OCRIter(hp.batch_size, init_states, hp.data_path, name='val')

        if not os.path.exists('checkpoint'):
            os.makedirs('checkpoint')

        head = '%(asctime)-15s %(message)s'
        logging.basicConfig(level=logging.DEBUG, format=head)

        module = mx.mod.BucketingModule(
            context=contexts,
            sym_gen=lstm.sym_gen,
            default_bucket_key=max(hp.bucket_len),
        )

        metrics = CtcMetrics()
        module.fit(train_data=data_train,
                   eval_data=data_val,
                   eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True),
                   optimizer='sgd',
                   optimizer_params={'learning_rate': hp.learning_rate,
                                     'momentum': hp.momentum,
                                     'wd': 0.00001,
                                     },
                   initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
                   arg_params=arg_params,
                   aux_params=aux_params,
                   num_epoch=hp.num_epoch,
                   batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50),
                   epoch_end_callback=mx.callback.do_checkpoint(args.prefix, args.save_epoch),
                   )
    except KeyboardInterrupt:
        print("W: interrupt received, stopping...")
Example #2
0
 def predict_DataIter(self, img_path):
     """Predict the label sequence of one image via ``self.mod.forward``.

     The preprocessed image is wrapped in ``Io_class`` (a custom data
     iterator/batch wrapper defined elsewhere) and pushed through the module
     with ``forward``.  The first output is reduced with an argmax over its
     last axis and decoded by ``CtcMetrics.ctc_label``.

     Returns:
         list of decoded labels, each shifted down by one
         (NOTE(review): presumably undoing a +1 label offset used in
         training -- confirm).
     """
     batch = Io_class(self._preprocess_image(img_path))
     self.mod.forward(batch)
     scores = self.mod.get_outputs()[0].asnumpy()
     best_path = np.argmax(scores, axis=-1).tolist()
     decoded = CtcMetrics.ctc_label(best_path)
     return [label - 1 for label in decoded]
Example #3
0
 def predict(self, img_path):
     """Predict the label sequence of one image via ``self.mod.predict``.

     The preprocessed image is wrapped in a label-less, single-sample
     ``mx.io.NDArrayIter`` and run for one batch.  The prediction is reduced
     with an argmax over its last axis and decoded by
     ``CtcMetrics.ctc_label``.

     Returns:
         list of decoded labels, each shifted down by one
         (NOTE(review): presumably undoing a +1 label offset used in
         training -- confirm).
     """
     single_batch = mx.io.NDArrayIter(data=self._preprocess_image(img_path),
                                      label=None,
                                      batch_size=1)
     scores = self.mod.predict(eval_data=single_batch, num_batch=1).asnumpy()
     decoded = CtcMetrics.ctc_label(np.argmax(scores, axis=-1).tolist())
     return [label - 1 for label in decoded]
Example #4
0
 def test(val_data, ctx):
     """Evaluate the global ``net`` on *val_data* with a ``CtcMetrics`` metric.

     Each ``(data, labels)`` batch is converted with ``nd.array`` and split
     along axis 0 across the contexts in *ctx* (uneven splits allowed).  The
     shards are fed through ``net`` and accumulated in the metric.

     Args:
         val_data: iterable yielding ``(data, labels)`` batches.
         ctx: list of contexts to split each batch across.

     Returns:
         The result of ``metric.get()``; the caller in ``train`` unpacks it
         as a ``(name, value)`` pair.
     """
     metric = CtcMetrics(num_classes=config.num_classes)
     metric.reset()

     def _shard(array):
         # Spread one batch over every context along the batch axis.
         return gluon.utils.split_and_load(nd.array(array),
                                           ctx_list=ctx,
                                           batch_axis=0,
                                           even_split=False)

     for batch_data, batch_labels in val_data:
         data_shards = _shard(batch_data)
         label_shards = _shard(batch_labels)
         predictions = [net(shard) for shard in data_shards]
         metric.update(label_shards, predictions)
     return metric.get()
Example #5
0
    def train(ctx, batch_size):
        """Train the global ``net`` with Adam and a step-decay learning rate.

        Each epoch does the following:
          * iterates the training ``DataLoader``, splitting every batch
            across the contexts in *ctx*;
          * back-propagates ``loss_fn`` and updates the parameters;
          * logs the running accuracy and loss every ``args.frequent``
            batches.

        After every epoch the validation accuracy is computed with
        ``test``, the error history is plotted, and the parameters are
        saved whenever the validation accuracy improves.  The network is
        exported every ``args.save_period`` epochs and once more after the
        final epoch.

        Args:
            ctx: list of contexts to split each batch across.
            batch_size: samples per batch for both data loaders.
        """
        # NOTE(review): `net` is not initialized here; reset_ctx only moves
        # existing parameters, so this assumes the caller initialized it.
        train_data = DataLoader(ImageDataset(root=default.dataset_path, train=True),
                                batch_size=batch_size, shuffle=True,
                                num_workers=num_workers)
        val_data = DataLoader(ImageDataset(root=default.dataset_path, train=False),
                              batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)

        net.collect_params().reset_ctx(ctx)
        lr = args.lr
        end_lr = args.end_lr
        lr_decay = args.lr_decay
        lr_decay_step = args.lr_decay_step
        all_step = len(train_data)
        # Decay every `lr_decay_step` epochs, expressed in update steps.
        schedule = mx.lr_scheduler.FactorScheduler(step=lr_decay_step * all_step,
                                                   factor=lr_decay,
                                                   stop_factor_lr=end_lr)
        adam_optimizer = mx.optimizer.Adam(learning_rate=lr,
                                           lr_scheduler=schedule)
        trainer = gluon.Trainer(net.collect_params(), optimizer=adam_optimizer)

        train_metric = CtcMetrics()
        train_history = TrainingHistory(['training-error', 'validation-error'])

        best_val_score = 0

        save_period = args.save_period
        save_dir = args.save_dir
        model_name = args.prefix
        plot_path = args.save_dir
        epochs = args.end_epoch
        frequent = args.frequent
        last_epoch = None  # stays None when no epoch runs
        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            train_loss = 0
            num_batch = 0
            num_samples = 0
            tic_b = time.time()
            for datas, labels in train_data:
                batch_nd = nd.array(datas)
                # The loader keeps a smaller final batch; use the real size
                # for gradient normalization and loss averaging.
                cur_batch = batch_nd.shape[0]
                data = gluon.utils.split_and_load(batch_nd,
                                                  ctx_list=ctx,
                                                  batch_axis=0,
                                                  even_split=False)
                label = gluon.utils.split_and_load(nd.array(labels),
                                                   ctx_list=ctx,
                                                   batch_axis=0,
                                                   even_split=False)
                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(cur_batch)
                train_loss += sum(l.sum().asscalar() for l in loss)
                num_samples += cur_batch

                train_metric.update(label, output)
                _, acc = train_metric.get()
                num_batch += 1
                if num_batch % frequent == 0:
                    train_loss_b = train_loss / num_samples
                    logging.info(
                        '[Epoch %d] [num_batch %d] train_acc=%f loss=%f time/batch: %f'
                        % (epoch, num_batch, acc, train_loss_b,
                           (time.time() - tic_b) / num_batch))
            train_loss /= num_samples
            _, acc = train_metric.get()
            _, val_acc = test(val_data, ctx)
            train_history.update([1 - acc, 1 - val_acc])
            train_history.plot(save_path='%s/%s_history.png' %
                               (plot_path, model_name))
            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters(
                    '%s/%.4f-crnn-%s-%d-best.params' %
                    (save_dir, best_val_score, model_name, epoch))
            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss, time.time() - tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                symbol_file = os.path.join(save_dir, model_name)
                net.export(path=symbol_file, epoch=epoch)
            last_epoch = epoch

        # Label the final export with the epoch that produced these weights.
        # The former `epoch - 1` mislabeled them, and crashed with
        # UnboundLocalError when no epoch ran.
        if save_period and save_dir and last_epoch is not None:
            symbol_file = os.path.join(save_dir, model_name)
            net.export(path=symbol_file, epoch=last_epoch)
def main():
    """Train a CRNN text recognizer with ``mx.mod.BucketingModule``.

    The function does the following:
      * builds ``TextIter`` iterators for the ``'train'`` and ``'test'``
        image sets;
      * derives the learning rate that applies at ``args.begin_epoch`` from
        the comma-separated decay epochs in ``args.lr_step`` (factor 0.5 per
        decay);
      * optionally loads pretrained weights;
      * runs ``module.fit``, checkpointing to ``args.prefix`` every 10
        epochs.

    Side effect: sets ``args.ctx_num`` and ``args.per_batch_size``.

    Raises:
        NotImplementedError: if ``config.use_lstm`` is false, because only
            the LSTM variant builds its data iterators.
    """
    args = parse_args()
    # NOTE(review): training runs on a single, hard-coded GPU.
    ctx = [mx.gpu(0)]

    args.ctx_num = len(ctx)
    args.per_batch_size = args.batch_size // args.ctx_num

    if not config.use_lstm:
        # No iterator is built for the non-LSTM network (crnn_no_lstm,
        # formerly configured with SGD, momentum=0.9, wd=0.0002), so the
        # code below would die with an UnboundLocalError on `train_iter`.
        raise NotImplementedError(
            'config.use_lstm=False is not supported: no data iterator is '
            'built for the non-LSTM network')

    # Named initial-state shapes: 2 * num_lstm_layer entries each of
    # `l<i>_init_c` and `l<i>_init_h`.
    num_states = config.num_lstm_layer * 2
    state_shape = (args.batch_size, config.num_hidden)
    init_c = [('l%d_init_c' % i, state_shape) for i in range(num_states)]
    init_h = [('l%d_init_h' % i, state_shape) for i in range(num_states)]
    init_states = init_c + init_h

    train_iter = TextIter(dataset_path=args.dataset_path,
                          image_path=config.image_path,
                          image_set='train',
                          batch_size=args.batch_size,
                          init_states=init_states)
    val_iter = TextIter(dataset_path=args.dataset_path,
                        image_path=config.image_path,
                        image_set='test',
                        batch_size=args.batch_size,
                        init_states=init_states)

    metrics = CtcMetrics()
    initializer = mx.init.Xavier(factor_type="in", magnitude=2.34)
    rescale = 1.0 / args.ctx_num

    base_lr = args.lr
    lr_factor = 0.5
    lr_epoch = [int(epoch) for epoch in args.lr_step.split(',')]
    # Decay epochs still ahead of begin_epoch, made relative to it.
    lr_epoch_diff = [epoch - args.begin_epoch for epoch in lr_epoch
                     if epoch > args.begin_epoch]
    # Apply the decays that already happened before begin_epoch up front.
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [int(epoch * train_iter.num_samples() / args.batch_size)
                for epoch in lr_epoch_diff]
    logger.info('lr %f lr_epoch_diff %s lr_iters %s',
                lr, lr_epoch_diff, lr_iters)
    # MultiFactorScheduler rejects an empty step list; when no decay epoch
    # remains, train without a scheduler at the constant rate `lr`.
    lr_scheduler = (mx.lr_scheduler.MultiFactorScheduler(lr_iters, lr_factor)
                    if lr_iters else None)

    optimizer = 'AdaDelta'
    optimizer_params = {
        'wd': 0.00001,
        # Pass the begin_epoch-adjusted rate logged above, not base_lr.
        'learning_rate': lr,
        'lr_scheduler': lr_scheduler,
        'rescale_grad': rescale,
        'clip_gradient': None,
    }

    if args.pretrained:
        _, arg_params, aux_params = mx.model.load_checkpoint(
            args.pretrained, args.pretrained_epoch)
    else:
        arg_params = None
        aux_params = None

    module = mx.mod.BucketingModule(
        sym_gen=crnn_lstm,
        default_bucket_key=train_iter.default_bucket_key,
        context=ctx)
    module.fit(
        train_data=train_iter,
        eval_data=val_iter,
        begin_epoch=args.begin_epoch,
        num_epoch=args.end_epoch,
        # use metrics.accuracy or metrics.accuracy_lcs
        eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True),
        optimizer=optimizer,
        optimizer_params=optimizer_params,
        initializer=initializer,
        arg_params=arg_params,
        aux_params=aux_params,
        batch_end_callback=mx.callback.Speedometer(args.batch_size,
                                                   args.frequent),
        epoch_end_callback=mx.callback.do_checkpoint(args.prefix, period=10),
    )