Example #1
import os

import mxnet as mx
from mxnet import autograd, lr_scheduler as ls
from mxnet.gluon import Trainer
from mxnet.gluon.data import DataLoader
from mxboard import SummaryWriter

# Assumed project-local modules: the ICDAR dataset wrapper, the PSENet
# model, and its dice loss. Adjust the import paths to your repo layout.
from dataset import ICDAR
from model import PSENet
from loss import DiceLoss


def train(data_dir, pretrain_model, epochs=3, lr=0.001, wd=5e-4, momentum=0.9,
          batch_size=5, ctx=mx.cpu(), verbose_step=2, ckpt='ckpt'):

    icdar_loader = ICDAR(data_dir=data_dir)
    loader = DataLoader(icdar_loader, batch_size=batch_size, shuffle=True)
    net = PSENet(num_kernels=7, ctx=ctx)
    # Initialize parameters, then optionally warm-start from a pretrained file.
    net.collect_params().initialize(mx.init.Normal(sigma=0.01), ctx=ctx)
    if pretrain_model:
        net.load_parameters(pretrain_model, ctx=ctx,
                            allow_missing=True, ignore_extra=True)
    pse_loss = DiceLoss(lam=0.7)

    # One trainer update per batch: decay the lr polynomially over the run.
    poly_sch = ls.PolyScheduler(max_update=icdar_loader.length * epochs // batch_size,
                                base_lr=lr)
    trainer = Trainer(
        net.collect_params(),
        'sgd',
        {
            'learning_rate': lr,
            'wd': wd,
            'momentum': momentum,
            'lr_scheduler': poly_sch
        })
    summary_writer = SummaryWriter(ckpt)
    for e in range(epochs):
        cumulative_loss = 0

        for i, item in enumerate(loader):
            im, score_maps, kernels, training_masks, ori_img = item
            
            im = im.as_in_context(ctx)
            score_maps = score_maps[:, ::4, ::4].as_in_context(ctx)
            kernels = kernels[:, ::4, ::4, :].as_in_context(ctx)
            training_masks = training_masks[:, ::4, ::4].as_in_context(ctx)

            with autograd.record():
                kernels_pred = net(im)
                loss = pse_loss(score_maps, kernels, kernels_pred, training_masks)
            # Backward outside the record scope; step() normalizes by batch size.
            loss.backward()
            trainer.step(batch_size)
            if i % verbose_step == 0:
                global_steps = icdar_loader.length * e + i * batch_size
                summary_writer.add_image('score_map', score_maps[0:1, :, :], global_steps)
                summary_writer.add_image('score_map_pred', kernels_pred[0:1, -1, :, :], global_steps)
                summary_writer.add_image('kernel_map', kernels[0:1, :, :, 0], global_steps)
                summary_writer.add_image('kernel_map_pred', kernels_pred[0:1, 0, :, :], global_steps)
                summary_writer.add_scalar('loss', mx.nd.mean(loss).asscalar(), global_steps)
                summary_writer.add_scalar('c_loss', mx.nd.mean(pse_loss.C_loss).asscalar(), global_steps)
                summary_writer.add_scalar('kernel_loss', mx.nd.mean(pse_loss.kernel_loss).asscalar(), global_steps)
                summary_writer.add_scalar('pixel_accuracy', pse_loss.pixel_acc, global_steps)
                print("step: {}, loss: {}, score_loss: {}, kernel_loss: {}, pixel_acc: {}".format(i * batch_size, mx.nd.mean(loss).asscalar(), \
                    mx.nd.mean(pse_loss.C_loss).asscalar(), mx.nd.mean(pse_loss.kernel_loss).asscalar(), \
                        pse_loss.pixel_acc))
            cumulative_loss += mx.nd.mean(loss).asscalar()
        print("Epoch {}, loss: {}".format(e, cumulative_loss))
        net.save_parameters(os.path.join(ckpt, 'model_{}.param'.format(e)))
    summary_writer.close()
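
A minimal invocation of train() above; the data directory below is a
hypothetical placeholder for an ICDAR-style dataset.

if __name__ == '__main__':
    # Hypothetical path; pass pretrain_model=None to train from scratch.
    train(data_dir='./icdar2015', pretrain_model=None,
          epochs=3, batch_size=5, ctx=mx.cpu())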
Example #2

import math

from mxnet import lr_scheduler

# `args`, `num_training_samples`, `num_workers`, and `batch_size` come from
# the surrounding training script.
epoch_size = \
    int(math.ceil(int(num_training_samples // num_workers) / batch_size))

if args.lr_mode == 'step':
    lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(',')]
    steps = [epoch_size * x for x in lr_decay_epoch]
    lr_sched = lr_scheduler.MultiFactorScheduler(
        step=steps,
        factor=args.lr_decay,
        base_lr=(args.lr * num_workers),
        warmup_steps=(args.warmup_epochs * epoch_size),
        warmup_begin_lr=args.warmup_lr)
elif args.lr_mode == 'poly':
    lr_sched = lr_scheduler.PolyScheduler(args.num_epochs * epoch_size,
                                          base_lr=(args.lr * num_workers),
                                          pwr=2,
                                          warmup_steps=(args.warmup_epochs *
                                                        epoch_size),
                                          warmup_begin_lr=args.warmup_lr)
elif args.lr_mode == 'cosine':
    lr_sched = lr_scheduler.CosineScheduler(args.num_epochs * epoch_size,
                                            base_lr=(args.lr * num_workers),
                                            warmup_steps=(args.warmup_epochs *
                                                          epoch_size),
                                            warmup_begin_lr=args.warmup_lr)
else:
    raise ValueError('Invalid lr mode')
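
Any branch above yields a standard mxnet.lr_scheduler.LRScheduler, which is a
callable from the global update count to a learning rate, so the schedule can
be sanity-checked before it is handed to a trainer; a minimal sketch using the
lr_sched and epoch_size defined above:

for update in (0, epoch_size, 10 * epoch_size):
    # __call__(num_update) returns the lr the optimizer would use at that step.
    print("update {}: lr {:.6f}".format(update, lr_sched(update)))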


# Function for reading data from record file
# For more details about data loading in MXNet, please refer to
# https://mxnet.incubator.apache.org/tutorials/basic/data.html?highlight=imagerecorditer
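
A sketch of the record-file reader that the comment above introduces, built on
mx.io.ImageRecordIter; the file path and data shape below are assumptions.

import mxnet as mx

def get_record_iter(rec_file, batch_size, data_shape=(3, 224, 224)):
    # path_imgrec points at a packed RecordIO (.rec) file; data_shape is
    # (channels, height, width) after decoding.
    return mx.io.ImageRecordIter(path_imgrec=rec_file,
                                 data_shape=data_shape,
                                 batch_size=batch_size,
                                 shuffle=True)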
Example #3
import logging
import os

import numpy as np
import mxnet as mx
from mxnet import autograd, gluon, lr_scheduler as ls
from mxnet.gluon.data import DataLoader
from mxboard import SummaryWriter

# Assumed project-local modules: the EAST model, its loss, and the dataset
# wrapper. Adjust the import paths to your repo layout.
import east
from loss import EASTLoss
from dataset import text_detection_data


def main(train_dir,
         ctx=None,
         lr=0.0001,
         epochs=20,
         batch_size=16,
         checkpoint_path='model',
         debug=False):
    summ_writer = SummaryWriter(checkpoint_path)
    # Resolve the device: a non-negative GPU id selects that GPU, else the CPU.
    ctx = -1 if ctx is None else int(ctx)
    context = mx.gpu(ctx) if ctx >= 0 else mx.cpu()
    # dataloader
    ic_data = text_detection_data(image_dir=train_dir)
    ic_dataloader = DataLoader(dataset=ic_data,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=16)
    data_num = len(ic_dataloader) * batch_size
    # model
    east_model = east.EAST(nclass=2, text_scale=1024)
    # east_model = east(text_scale=1024)

    east_model.collect_params().initialize(init=mx.init.Xavier(),
                                           verbose=True,
                                           ctx=context)
    if not debug:
        east_model.hybridize()
    # One trainer update per batch: decay the lr polynomially over the run.
    poly_sch = ls.PolyScheduler(max_update=len(ic_dataloader) * epochs,
                                base_lr=lr)

    trainer = gluon.Trainer(
        east_model.collect_params(), 'sgd', {
            'learning_rate': lr,
            'wd': 1e-5,
            'momentum': 0.9,
            'clip_gradient': 5,
            'lr_scheduler': poly_sch
        })
    EAST_loss = EASTLoss(cls_weight=0.01, iou_weight=1.0, angle_weight=20)
    step = 0
    lr_counter = 0
    lr_steps = [5, 10, 15, 20]
    lr_factor = 0.9

    for epoch in range(epochs):
        loss = []
        # Additional manual step decay on top of the scheduler.
        if lr_counter < len(lr_steps) and epoch == lr_steps[lr_counter]:
            trainer.set_learning_rate(trainer.learning_rate * lr_factor)
            lr_counter += 1
        for i, batch_data in enumerate(ic_dataloader):
            # Move the batch to the target device.
            im, score_map, geo_map, training_mask = (
                x.as_in_context(context) for x in batch_data)

            with autograd.record(train_mode=True):
                f_score, f_geo = east_model(im)
                batch_loss = EAST_loss(score_map, f_score, geo_map, f_geo,
                                       training_mask)
                loss.append(batch_loss)
            # Backward outside the record scope; step() normalizes by batch size.
            batch_loss.backward()
            trainer.step(batch_size)
            step = epoch * data_num + i * batch_size
            # Mean of this epoch's batch losses so far, for the summary writer.
            model_loss = np.mean([x.asnumpy()[0] for x in loss])
            summ_writer.add_scalar('model_loss', model_loss, step)
            logging.info("step: {}, loss: {}".format(
                step, batch_loss.mean().asscalar()))
        ckpt_file = os.path.join(checkpoint_path,
                                 "model_{}.params".format(step))
        east_model.save_parameters(ckpt_file)
        logging.info("save model to {}".format(ckpt_file))