# Shared imports for the snippets below. Project-local names (Hyperparams,
# OCRIter, TextIter, CtcMetrics, ImageDataset, lstm, crnn_lstm, crnn_no_lstm,
# config, default, args, net, loss_fn, num_workers, logger, parse_args, ...)
# are assumed to be defined elsewhere in this repository; TrainingHistory is
# assumed to be gluoncv's helper.
import os
import time
import logging

import numpy as np
import mxnet as mx
from mxnet import autograd as ag, gluon, nd
from mxnet.gluon.data import DataLoader
from gluoncv.utils import TrainingHistory


def main():
    args = parse_args()
    if args.loss not in ('ctc', 'warpctc'):
        raise ValueError("Invalid loss '{}' (must be 'ctc' or 'warpctc')".format(args.loss))
    hp = Hyperparams()
    try:
        # Optionally resume from a "<model_path>,<epoch>" checkpoint spec.
        if args.resume:
            model_path, epoch = args.resume.split(",")
            _, arg_params, aux_params = mx.model.load_checkpoint(model_path, int(epoch))
        else:
            arg_params, aux_params = None, None

        if args.gpu:
            contexts = [mx.context.gpu(i) for i in range(args.gpu)]
        else:
            contexts = [mx.context.cpu(i) for i in range(args.cpu)]

        init_states = lstm.init_states(hp.batch_size, hp.num_lstm_layer, hp.num_hidden)
        data_train = OCRIter(hp.batch_size, init_states, hp.data_path, name='train')
        data_val = OCRIter(hp.batch_size, init_states, hp.data_path, name='val')

        if not os.path.exists('checkpoint'):
            os.makedirs('checkpoint')

        head = '%(asctime)-15s %(message)s'
        logging.basicConfig(level=logging.DEBUG, format=head)

        module = mx.mod.BucketingModule(
            context=contexts,
            sym_gen=lstm.sym_gen,
            default_bucket_key=max(hp.bucket_len),
        )
        metrics = CtcMetrics()
        module.fit(train_data=data_train,
                   eval_data=data_val,
                   eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True),
                   optimizer='sgd',
                   optimizer_params={'learning_rate': hp.learning_rate,
                                     'momentum': hp.momentum,
                                     'wd': 0.00001},
                   initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
                   arg_params=arg_params,
                   aux_params=aux_params,
                   num_epoch=hp.num_epoch,
                   batch_end_callback=mx.callback.Speedometer(hp.batch_size, 50),
                   epoch_end_callback=mx.callback.do_checkpoint(args.prefix, period=args.save_epoch))
    except KeyboardInterrupt:
        print("W: interrupt received, stopping...")
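# A minimal parse_args sketch, inferred only from the attributes main() reads
# (args.loss, args.resume, args.gpu, args.cpu, args.prefix, args.save_epoch).
# The defaults and help strings are assumptions; the repository's real
# parse_args may differ.
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Train the CTC OCR model.')
    parser.add_argument('--loss', default='ctc', help="'ctc' or 'warpctc'")
    parser.add_argument('--resume', default=None,
                        help='"<model_path>,<epoch>" checkpoint to resume from')
    parser.add_argument('--gpu', type=int, default=0,
                        help='number of GPUs to use (0 means CPU)')
    parser.add_argument('--cpu', type=int, default=1,
                        help='number of CPU contexts when no GPU is used')
    parser.add_argument('--prefix', default='checkpoint/ocr',
                        help='checkpoint path prefix')
    parser.add_argument('--save-epoch', dest='save_epoch', type=int, default=1,
                        help='checkpoint period in epochs')
    return parser.parse_args()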
def predict_DataIter(self, img_path):
    '''Predict by wrapping the image in a custom mx.io.DataIter subclass
    (Io_class, defined elsewhere in this repo) and calling mod.forward.'''
    img = self._preprocess_image(img_path)
    img = Io_class(img)
    self.mod.forward(img)
    res = self.mod.get_outputs()[0].asnumpy()
    # Greedy CTC decoding: argmax per time step, then collapse repeats/blanks.
    prediction = CtcMetrics.ctc_label(np.argmax(res, axis=-1).tolist())
    # Shift by one because class 0 is reserved for the CTC blank.
    prediction = [p - 1 for p in prediction]
    return prediction
def predict(self, img_path):
    '''Predict with mx.io.NDArrayIter and mod.predict.'''
    img = self._preprocess_image(img_path)
    img = mx.io.NDArrayIter(data=img, label=None, batch_size=1)
    res = self.mod.predict(eval_data=img, num_batch=1)
    res = res.asnumpy()
    # Greedy CTC decoding, as in predict_DataIter above.
    prediction = CtcMetrics.ctc_label(np.argmax(res, axis=-1).tolist())
    prediction = [p - 1 for p in prediction]
    return prediction
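# Usage sketch with hypothetical names: the class wrapping the two predict
# methods above is not shown here, so `Predictor` and its constructor
# arguments are stand-ins; the class is assumed to bind `self.mod` (a loaded
# mx.mod.Module) and `self._preprocess_image` before prediction.
predictor = Predictor('checkpoint/ocr', epoch=100)  # hypothetical
print(predictor.predict('sample.jpg'))              # decoded class indices, blanks removed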
def test(val_data, ctx):
    metric = CtcMetrics(num_classes=config.num_classes)
    metric.reset()
    for datas, labels in val_data:
        data = gluon.utils.split_and_load(nd.array(datas), ctx_list=ctx,
                                          batch_axis=0, even_split=False)
        label = gluon.utils.split_and_load(nd.array(labels), ctx_list=ctx,
                                           batch_axis=0, even_split=False)
        output = [net(X) for X in data]
        metric.update(label, output)
    return metric.get()
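# Usage sketch for test(), assuming `net` is already initialized and
# `val_data` is built as in train() below:
ctx = [mx.gpu(0)] if mx.context.num_gpus() > 0 else [mx.cpu()]
name, val_acc = test(val_data, ctx)
print('%s: %f' % (name, val_acc))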
def train(ctx, batch_size):
    # net.initialize(mx.init.Xavier(), ctx=ctx)  # net is initialized elsewhere
    train_data = DataLoader(ImageDataset(root=default.dataset_path, train=True),
                            batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_data = DataLoader(ImageDataset(root=default.dataset_path, train=False),
                          batch_size=batch_size, shuffle=True, num_workers=num_workers)
    net.collect_params().reset_ctx(ctx)

    # Decay the learning rate by `lr_decay` every `lr_decay_step` epochs,
    # never dropping below `end_lr`.
    lr = args.lr
    end_lr = args.end_lr
    lr_decay = args.lr_decay
    lr_decay_step = args.lr_decay_step
    all_step = len(train_data)
    schedule = mx.lr_scheduler.FactorScheduler(step=lr_decay_step * all_step,
                                               factor=lr_decay, stop_factor_lr=end_lr)
    adam_optimizer = mx.optimizer.Adam(learning_rate=lr, lr_scheduler=schedule)
    trainer = gluon.Trainer(net.collect_params(), optimizer=adam_optimizer)

    train_metric = CtcMetrics()
    train_history = TrainingHistory(['training-error', 'validation-error'])

    iteration = 0
    best_val_score = 0
    save_period = args.save_period
    save_dir = args.save_dir
    model_name = args.prefix
    plot_path = args.save_dir
    epochs = args.end_epoch
    frequent = args.frequent

    for epoch in range(epochs):
        tic = time.time()
        train_metric.reset()
        train_loss = 0
        num_batch = 0
        tic_b = time.time()
        for datas, labels in train_data:
            data = gluon.utils.split_and_load(nd.array(datas), ctx_list=ctx,
                                              batch_axis=0, even_split=False)
            label = gluon.utils.split_and_load(nd.array(labels), ctx_list=ctx,
                                               batch_axis=0, even_split=False)
            with ag.record():
                output = [net(X) for X in data]
                loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
            for l in loss:
                l.backward()
            trainer.step(batch_size)
            train_loss += sum([l.sum().asscalar() for l in loss])
            train_metric.update(label, output)
            name, acc = train_metric.get()
            iteration += 1
            num_batch += 1
            if num_batch % frequent == 0:
                train_loss_b = train_loss / (batch_size * num_batch)
                logging.info('[Epoch %d] [num_batch %d] train_acc=%f loss=%f time/batch: %f' %
                             (epoch, num_batch, acc, train_loss_b,
                              (time.time() - tic_b) / num_batch))

        train_loss /= batch_size * num_batch
        name, acc = train_metric.get()
        name, val_acc = test(val_data, ctx)
        train_history.update([1 - acc, 1 - val_acc])
        train_history.plot(save_path='%s/%s_history.png' % (plot_path, model_name))

        if val_acc > best_val_score:
            best_val_score = val_acc
            net.save_parameters('%s/%.4f-crnn-%s-%d-best.params'
                                % (save_dir, best_val_score, model_name, epoch))

        logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f'
                     % (epoch, acc, val_acc, train_loss, time.time() - tic))

        if save_period and save_dir and (epoch + 1) % save_period == 0:
            symbol_file = os.path.join(save_dir, model_name)
            net.export(path=symbol_file, epoch=epoch)
            # net.save_parameters('%s/crnn-%s-%d.params' % (save_dir, model_name, epoch))

    # Export once more after the last epoch.
    if save_period and save_dir:
        symbol_file = os.path.join(save_dir, model_name)
        net.export(path=symbol_file, epoch=epoch - 1)
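# `net` and `loss_fn` above are module-level globals that this snippet does
# not define. A plausible loss for CRNN training, given the CtcMetrics usage
# above, is gluon's CTC loss (an assumption, not necessarily this
# repository's actual choice):
loss_fn = gluon.loss.CTCLoss(layout='NTC', label_layout='NT')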
def main():
    args = parse_args()
    # GPU contexts; earlier revisions parsed CUDA_VISIBLE_DEVICES or used
    # several GPUs (e.g. mx.gpu(0), mx.gpu(2), mx.gpu(4), mx.gpu(6)).
    ctx = [mx.gpu(0)]
    args.ctx_num = len(ctx)
    args.per_batch_size = args.batch_size // args.ctx_num

    if config.use_lstm:
        # One (init_c, init_h) pair per direction of each bidirectional LSTM layer.
        init_c = [('l%d_init_c' % l, (args.batch_size, config.num_hidden))
                  for l in range(config.num_lstm_layer * 2)]
        init_h = [('l%d_init_h' % l, (args.batch_size, config.num_hidden))
                  for l in range(config.num_lstm_layer * 2)]
        init_states = init_c + init_h
        # data_names = ['data'] + [x[0] for x in init_states]
        train_iter = TextIter(dataset_path=args.dataset_path, image_path=config.image_path,
                              image_set='train', batch_size=args.batch_size,
                              init_states=init_states)
        val_iter = TextIter(dataset_path=args.dataset_path, image_path=config.image_path,
                            image_set='test', batch_size=args.batch_size,
                            init_states=init_states)
    # else:  # non-LSTM iterators, left commented out in the original; the
    #         # non-LSTM branch below therefore has no train_iter/val_iter.
    #     data_names = ['data']
    #     train_iter = TextIter(path=args.dataset_path, data_root=config.image_path,
    #                           batch_size=args.batch_size, num_label=100, init_states=init_states)
    #     val_iter = TextIter(path=args.dataset_path, data_root=config.image_path,
    #                         batch_size=args.batch_size, num_label=100, init_states=init_states)

    metrics = CtcMetrics()
    # Alternative initializers per network family were considered
    # (resnet-style: factor_type="out"; inception-style: factor_type="in").
    initializer = mx.init.Xavier(factor_type="in", magnitude=2.34)

    # Step the learning rate down by `lr_factor` at the epochs in args.lr_step,
    # skipping steps already behind args.begin_epoch when resuming.
    base_lr = args.lr
    lr_factor = 0.5
    lr_epoch = [int(epoch) for epoch in args.lr_step.split(',')]
    lr_epoch_diff = [epoch - args.begin_epoch for epoch in lr_epoch
                     if epoch > args.begin_epoch]
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    lr_iters = [int(epoch * train_iter.num_samples() / args.batch_size)
                for epoch in lr_epoch_diff]
    logger.info('lr %f lr_epoch_diff %s lr_iters %s' % (lr, lr_epoch_diff, lr_iters))
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(lr_iters, lr_factor)

    if config.use_lstm:
        optimizer = 'AdaDelta'
        optimizer_params = {'wd': 0.00001,
                            'learning_rate': lr,  # resume-adjusted learning rate
                            'lr_scheduler': lr_scheduler,
                            'rescale_grad': 1.0 / args.ctx_num,
                            'clip_gradient': None}
    else:
        optimizer = 'sgd'
        optimizer_params = {'momentum': 0.9,
                            'wd': 0.0002,
                            'learning_rate': lr,  # resume-adjusted learning rate
                            'lr_scheduler': lr_scheduler,
                            'rescale_grad': 1.0 / args.ctx_num,
                            'clip_gradient': None}

    if args.pretrained:
        _, arg_params, aux_params = mx.model.load_checkpoint(args.pretrained,
                                                             args.pretrained_epoch)
    else:
        arg_params = None
        aux_params = None

    if config.use_lstm:
        module = mx.mod.BucketingModule(sym_gen=crnn_lstm,
                                        default_bucket_key=train_iter.default_bucket_key,
                                        context=ctx)
    else:
        module = mx.mod.BucketingModule(sym_gen=crnn_no_lstm,
                                        default_bucket_key=train_iter.default_bucket_key,
                                        context=ctx)

    module.fit(train_data=train_iter,
               eval_data=val_iter,
               begin_epoch=args.begin_epoch,
               num_epoch=args.end_epoch,
               # allow_missing=True,
               # use metrics.accuracy or metrics.accuracy_lcs
               eval_metric=mx.metric.np(metrics.accuracy, allow_extra_outputs=True),
               optimizer=optimizer,
               optimizer_params=optimizer_params,
               initializer=initializer,
               arg_params=arg_params,
               aux_params=aux_params,
               batch_end_callback=mx.callback.Speedometer(args.batch_size, args.frequent),
               epoch_end_callback=mx.callback.do_checkpoint(args.prefix, period=10))
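# Standard entry point, assuming this file is run directly as a script:
if __name__ == '__main__':
    main()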