Example #1
def main():

    global lr

    if cfg.dataset == 'total-text':

        trainset = TotalText(data_root='data/total-text',
                             ignore_list=None,
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))

        valset = TotalText(data_root='data/total-text',
                           ignore_list=None,
                           is_training=False,
                           transform=BaseTransform(size=cfg.input_size,
                                                   mean=cfg.means,
                                                   std=cfg.stds))
    else:
        # Any other dataset name would leave trainset/valset undefined below.
        raise ValueError('Unsupported dataset: {}'.format(cfg.dataset))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers)
    val_loader = data.DataLoader(valset,
                                 batch_size=cfg.batch_size,
                                 shuffle=False,
                                 num_workers=cfg.num_workers)

    log_dir = os.path.join(
        cfg.log_dir,
        datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
    mkdirs(log_dir)
    logger = LogSummary(log_dir)

    # Model
    model = TextNet()
    if cfg.mgpu:
        model = nn.DataParallel(model, device_ids=cfg.gpu_ids)

    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    criterion = TextLoss()
    lr = cfg.lr
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    print('Start training TextSnake.')

    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        train(model, train_loader, criterion, scheduler, optimizer, epoch,
              logger)
        with torch.no_grad():
            validation(model, val_loader, criterion, epoch, logger)

    print('End.')
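
Note: Example #1 is the only variant that disables autograd around validation via `with torch.no_grad():`; PyTorch also accepts the same guard as a decorator. A minimal sketch of a validation helper in that style, borrowing the batch layout and TextLoss signature from Example #9 (`to_device` is the repo's util.misc helper; the aggregation and print format here are illustrative, not the repo's actual code):

import torch
from util.misc import to_device  # repo helper (see Example #9's imports)

@torch.no_grad()  # same effect as the caller-side `with torch.no_grad():`
def validation(model, val_loader, criterion, epoch, logger):
    # logger is kept only to match the call site in main()
    model.eval()  # freeze dropout and batch-norm statistics
    total_loss, batches = 0.0, 0
    for (img, train_mask, tr_mask, tcl_mask, radius_map, sin_map,
            cos_map, meta) in val_loader:
        img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map = \
            to_device(img, train_mask, tr_mask, tcl_mask, radius_map,
                      sin_map, cos_map)
        output = model(img)
        tr_loss, tcl_loss, sin_loss, cos_loss, radii_loss = criterion(
            output, tr_mask, tcl_mask, sin_map, cos_map, radius_map, train_mask)
        total_loss += (tr_loss + tcl_loss + sin_loss + cos_loss +
                       radii_loss).item()
        batches += 1
    print('Validation Loss (epoch {}): {:.4f}'.format(
        epoch, total_loss / max(batches, 1)))
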
Example #2
def main():
    global total_iter
    total_iter = 0
    data_root = os.path.join('/home/shf/fudan_ocr_system/datasets/',
                             cfg.dataset)

    trainset = TotalText(data_root=data_root,
                         ignore_list=os.path.join(data_root,
                                                  'ignore_list.txt'),
                         is_training=True,
                         transform=NewAugmentation(size=cfg.input_size,
                                                   mean=cfg.means,
                                                   std=cfg.stds,
                                                   maxlen=1280,
                                                   minlen=512))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers)

    # Model
    model = TextNet(backbone=cfg.backbone, output_channel=7)
    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    criterion = TextLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)
    # if cfg.dataset == 'ArT_train':
    #     scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[10000, 50000], gamma=0.1)
    # elif cfg.dataset == 'LSVT_full_train':
    #     scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[10000, 50000], gamma=0.1)

    # load model if resume
    if cfg.resume is not None:
        checkpoint = load_model(cfg.resume)
        cfg.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optim'])
        # scheduler.load_state_dict(checkpoint['scheduler'])
        total_iter = checkpoint['epoch'] * len(train_loader)

    if not os.path.exists(os.path.join(cfg.summary_dir, cfg.exp_name)):
        os.mkdir(os.path.join(cfg.summary_dir, cfg.exp_name))
    summary_writer = SummaryWriter(
        log_dir=os.path.join(cfg.summary_dir, cfg.exp_name))

    print('Start training TextSnake.')

    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        adjust_learning_rate(optimizer, total_iter)
        train(model, train_loader, criterion, None, optimizer, epoch,
              summary_writer)

    print('End.')
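
`adjust_learning_rate` is called above but not defined in this snippet; Example #9 carries the definition this script evidently pairs with. As a standalone sketch, with the thresholds copied from Example #9 and `opt` renamed to this script's module-level `cfg`:

def adjust_learning_rate(optimizer, i):
    # Piecewise schedule keyed on images seen so far (iterations * batch size):
    # full lr below 100k images, 0.1x up to 400k, then an extra 0.94 decay
    # for every further 100k images.
    if 0 <= i * cfg.batch_size < 100000:
        lr = cfg.lr
    elif 100000 <= i * cfg.batch_size < 400000:
        lr = cfg.lr * 0.1
    else:
        lr = cfg.lr * 0.1 * (0.94 ** ((i * cfg.batch_size - 300000) // 100000))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
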
Example #3
def main():

    if cfg.dataset == 'total-text':

        trainset = TotalText(
            data_root='data/total-text',
            ignore_list='./dataset/total_text/ignore_list.txt',
            is_training=True,
            transform=Augmentation(size=cfg.input_size,
                                   mean=cfg.means,
                                   std=cfg.stds))

        valset = TotalText(data_root='data/total-text',
                           ignore_list=None,
                           is_training=False,
                           transform=BaseTransform(size=cfg.input_size,
                                                   mean=cfg.means,
                                                   std=cfg.stds))
    else:
        # Any other dataset name would leave trainset/valset undefined below.
        raise ValueError('Unsupported dataset: {}'.format(cfg.dataset))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers)
    val_loader = data.DataLoader(valset,
                                 batch_size=cfg.batch_size,
                                 shuffle=False,
                                 num_workers=cfg.num_workers)

    # Model
    model = TextNet()
    if cfg.mgpu:
        model = nn.DataParallel(model, device_ids=cfg.gpu_ids)

    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    criterion = TextLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    print('Start training TextSnake.')

    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        train(model, train_loader, criterion, scheduler, optimizer, epoch)
        validation(model, val_loader, criterion)

    print('End.')
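
A caveat on the `nn.DataParallel` wrap restored above: the wrapper stores the network under a `module.` prefix, so checkpoints saved from a multi-GPU run carry `module.`-prefixed parameter names. A hedged sketch of loading such a checkpoint into a bare single-GPU model (a standard PyTorch idiom, not code from this repo, and assuming the file holds a bare state_dict):

import torch

def load_parallel_checkpoint(model, path):
    # Strip the 'module.' prefix that nn.DataParallel adds to every
    # parameter name before loading into an unwrapped module.
    state = torch.load(path, map_location='cpu')
    state = {k[len('module.'):] if k.startswith('module.') else k: v
             for k, v in state.items()}
    model.load_state_dict(state)
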
Example #4
def main():

    global lr

    if cfg.dataset == 'total-text':

        trainset = TotalText(data_root='data/total-text',
                             ignore_list=None,
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))

        valset = TotalText(data_root='data/total-text',
                           ignore_list=None,
                           is_training=False,
                           transform=BaseTransform(size=cfg.input_size,
                                                   mean=cfg.means,
                                                   std=cfg.stds))

    elif cfg.dataset == 'synth-text':
        trainset = SynthText(data_root='data/SynthText',
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
        valset = None
    else:
        pass

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers)
    # Setting pin_memory=True on a DataLoader allocates the returned tensors
    # in page-locked (pinned) host memory, which speeds up host-to-GPU copies.
    if valset:
        val_loader = data.DataLoader(valset,
                                     batch_size=cfg.batch_size,
                                     shuffle=False,
                                     num_workers=cfg.num_workers)
    else:
        val_loader = None

    log_dir = os.path.join(
        cfg.log_dir,
        datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
    logger = LogSummary(log_dir)

    # Model
    # Load the model; the backbone defaults to VGG.
    model = TextNet(is_training=True, backbone=cfg.net)
    if cfg.mgpu:
        model = nn.DataParallel(model)

    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True
    # Resume training from a checkpoint if requested.
    if cfg.resume:
        load_model(model, cfg.resume)

    # The loss combines loss_tr, loss_tcl, loss_radii, loss_sin and loss_cos.
    criterion = TextLoss()
    lr = cfg.lr
    # Adam optimizer.
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

    if cfg.dataset == 'synth-text':
        # Fixed learning rate; with last_epoch = -1 the initial lr is used.
        scheduler = FixLR(optimizer)
    else:
        # Step decay: multiply lr by 0.1 every step_size steps.
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    print('Start training TextSnake.')
    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        train(model, train_loader, criterion, scheduler, optimizer, epoch,
              logger)
        if valset:
            validation(model, val_loader, criterion, epoch, logger)

    print('End.')
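
The pinned-memory comment above describes a DataLoader option this example never actually enables. A minimal sketch of the pattern it refers to, using the stock `pin_memory` argument plus a non-blocking device copy (`trainset` and `cfg` are the objects from the script above; the loop body is illustrative):

import torch.utils.data as data

train_loader = data.DataLoader(trainset,
                               batch_size=cfg.batch_size,
                               shuffle=True,
                               num_workers=cfg.num_workers,
                               pin_memory=True)  # batches land in page-locked host memory

for img, *rest in train_loader:
    # A pinned source tensor lets .to() issue an asynchronous
    # host-to-device copy instead of a blocking one.
    img = img.to(cfg.device, non_blocking=True)
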
Example #5
log_dir = os.path.join(
    cfg.log_dir,
    datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
logger = LogSummary(log_dir)

model = TextNet(is_training=True, backbone=cfg.net)
if cfg.mgpu:
    model = nn.DataParallel(model)

model = model.to(cfg.device)

if cfg.cuda:
    cudnn.benchmark = True

if cfg.resume:
    load_model(model, cfg.resume)

criterion = TextLoss()
lr = cfg.lr
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

scheduler = lr_scheduler.StepLR(optimizer,
                                step_size=cfg.step_size,
                                gamma=cfg.gamma)

train_step = 0
print('Start training TextSnake.')
for epoch in range(cfg.start_epoch, cfg.max_epoch):
    train_step = train(model, train_loader, criterion, scheduler, optimizer,
                       epoch, logger, train_step)
    if valset:
        validation(model, val_loader, criterion, epoch, logger)
    save_model(model, "final", scheduler.get_lr(), optimizer)
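
This fragment omits its dataset and loader setup, so `train_loader` and `valset` come from code not shown. Besides the `train_step = 0` initialisation added above (mirroring Example #7 below), note that `scheduler.get_lr()` is deprecated in recent PyTorch in favour of `get_last_lr()`, which returns the learning rates most recently applied by `step()` (one per parameter group). Under a modern PyTorch the save call would read:

save_model(model, "final", scheduler.get_last_lr(), optimizer)
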
Example #6
def main():

    # global lr

    trainset = CustomText(data_root='data/{}'.format(cfg.dataset),
                          is_training=True,
                          transform=Augmentation(size=cfg.input_size,
                                                 mean=cfg.means,
                                                 std=cfg.stds))

    valset = CustomText(data_root='data/{}'.format(cfg.dataset),
                        is_training=False,
                        transform=BaseTransform(size=cfg.input_size,
                                                mean=cfg.means,
                                                std=cfg.stds))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers,
                                   collate_fn=my_collate)
    if valset:
        val_loader = data.DataLoader(valset,
                                     batch_size=cfg.batch_size,
                                     shuffle=False,
                                     num_workers=cfg.num_workers,
                                     collate_fn=my_collate)
    else:
        val_loader = None

    log_dir = os.path.join(
        cfg.log_dir,
        datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
    logger = LogSummary(log_dir)

    # Model
    model = TextNet(is_training=True, backbone=cfg.net)
    if cfg.mgpu:
        model = nn.DataParallel(model)

    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    if cfg.resume:
        load_model(model, cfg.resume)

    criterion = TextLoss()
    # lr = cfg.lr
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

    if cfg.dataset == 'synth-text':
        scheduler = FixLR(optimizer)
    else:
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    print('Start training TextSnake.')
    writer = SummaryWriter(logdir='process')
    for epoch in range(cfg.start_epoch, cfg.max_epoch):

        train(model, train_loader, criterion, scheduler, optimizer, epoch,
              logger)
        if valset:
            validation_loss = validation(model, val_loader, criterion, epoch,
                                         logger)

            writer.add_scalar('Validation Loss', validation_loss, epoch)

    print('End.')
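
`my_collate` is referenced here but never shown. A common reason for a custom collate_fn in detection pipelines is discarding samples the dataset failed to produce; a sketch under that assumption (purely hypothetical — the repo's my_collate may instead pad variable-length annotations):

from torch.utils.data.dataloader import default_collate

def my_collate(batch):
    # Hypothetical: drop samples the dataset returned as None,
    # then hand the survivors to PyTorch's default batching.
    batch = [sample for sample in batch if sample is not None]
    return default_collate(batch)
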
Example #7
    def Train(self):
        cfg.max_epoch = self.system_dict["params"]["max_epoch"]
        cfg.means = self.system_dict["params"]["means"]
        cfg.stds = self.system_dict["params"]["stds"]
        cfg.log_dir = self.system_dict["params"]["log_dir"]
        cfg.exp_name = self.system_dict["params"]["exp_name"]
        cfg.net = self.system_dict["params"]["net"]
        cfg.resume = self.system_dict["params"]["resume"]
        cfg.cuda = self.system_dict["params"]["cuda"]
        cfg.mgpu = self.system_dict["params"]["mgpu"]
        cfg.save_dir = self.system_dict["params"]["save_dir"]
        cfg.vis_dir = self.system_dict["params"]["vis_dir"]
        cfg.input_channel = self.system_dict["params"]["input_channel"]

        cfg.lr = self.system_dict["params"]["lr"]
        cfg.weight_decay = self.system_dict["params"]["weight_decay"]
        cfg.gamma = self.system_dict["params"]["gamma"]
        cfg.momentum = self.system_dict["params"]["momentum"]
        cfg.optim = self.system_dict["params"]["optim"]
        cfg.display_freq = self.system_dict["params"]["display_freq"]
        cfg.viz_freq = self.system_dict["params"]["viz_freq"]
        cfg.save_freq = self.system_dict["params"]["save_freq"]
        cfg.log_freq = self.system_dict["params"]["log_freq"]

        cfg.batch_size = self.system_dict["params"]["batch_size"]
        cfg.rescale = self.system_dict["params"]["rescale"]
        cfg.checkepoch = self.system_dict["params"]["checkepoch"]

        cfg.val_freq = 1000
        cfg.start_iter = 0
        cfg.loss = "CrossEntropyLoss"
        cfg.pretrain = False
        cfg.verbose = True
        cfg.viz = True
        cfg.lr_adjust = "step"  # options: fix, step
        cfg.stepvalues = []
        cfg.step_size = cfg.max_epoch // 2

        if cfg.cuda and torch.cuda.is_available():
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
            cudnn.benchmark = True
            cfg.device = torch.device("cuda")
        else:
            torch.set_default_tensor_type('torch.FloatTensor')
            cfg.device = torch.device("cpu")

        # Create weights saving directory
        if not os.path.exists(cfg.save_dir):
            os.mkdir(cfg.save_dir)

        # Create weights saving directory of target model
        model_save_path = os.path.join(cfg.save_dir, cfg.exp_name)

        if not os.path.exists(model_save_path):
            os.mkdir(model_save_path)

        if not os.path.exists(cfg.vis_dir):
            os.mkdir(cfg.vis_dir)

        if (self.system_dict["params"]["annotation_type"] == "text"):
            trainset = TotalText_txt(
                self.system_dict["params"]["train_img_folder"],
                self.system_dict["params"]["train_anno_folder"],
                ignore_list=None,
                is_training=True,
                transform=Augmentation(size=cfg.input_size,
                                       mean=cfg.means,
                                       std=cfg.stds))
            train_loader = data.DataLoader(trainset,
                                           batch_size=cfg.batch_size,
                                           shuffle=True,
                                           num_workers=cfg.num_workers)

            if (self.system_dict["params"]["val_dataset"]):
                valset = TotalText_txt(
                    self.system_dict["params"]["val_img_folder"],
                    self.system_dict["params"]["val_anno_folder"],
                    ignore_list=None,
                    is_training=False,
                    transform=BaseTransform(size=cfg.input_size,
                                            mean=cfg.means,
                                            std=cfg.stds))
                val_loader = data.DataLoader(valset,
                                             batch_size=cfg.batch_size,
                                             shuffle=False,
                                             num_workers=cfg.num_workers)
            else:
                valset = None

        elif (self.system_dict["params"]["annotation_type"] == "mat"):
            trainset = TotalText_mat(
                self.system_dict["params"]["train_img_folder"],
                self.system_dict["params"]["train_anno_folder"],
                ignore_list=None,
                is_training=True,
                transform=Augmentation(size=cfg.input_size,
                                       mean=cfg.means,
                                       std=cfg.stds))
            train_loader = data.DataLoader(trainset,
                                           batch_size=cfg.batch_size,
                                           shuffle=True,
                                           num_workers=cfg.num_workers)

            if (self.system_dict["params"]["val_dataset"]):
                valset = TotalText_mat(
                    self.system_dict["params"]["val_img_folder"],
                    self.system_dict["params"]["val_anno_folder"],
                    ignore_list=None,
                    is_training=False,
                    transform=BaseTransform(size=cfg.input_size,
                                            mean=cfg.means,
                                            std=cfg.stds))
                val_loader = data.DataLoader(valset,
                                             batch_size=cfg.batch_size,
                                             shuffle=False,
                                             num_workers=cfg.num_workers)
            else:
                valset = None

        log_dir = os.path.join(
            cfg.log_dir,
            datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
        logger = LogSummary(log_dir)

        model = TextNet(is_training=True, backbone=cfg.net)
        if cfg.mgpu:
            model = nn.DataParallel(model)

        model = model.to(cfg.device)

        if cfg.cuda:
            cudnn.benchmark = True

        if cfg.resume:
            self.load_model(model, cfg.resume)

        criterion = TextLoss()
        lr = cfg.lr
        if (cfg.optim == "adam"):
            optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
        else:
            # Non-Adam option: SGD with the configured momentum (cf. Example #8).
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=cfg.lr,
                                        momentum=cfg.momentum)

        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=cfg.step_size,
                                        gamma=cfg.gamma)

        train_step = 0
        print('Start training TextSnake.')
        for epoch in range(cfg.start_epoch, cfg.max_epoch):
            train_step = self.train(model, train_loader, criterion, scheduler,
                                    optimizer, epoch, logger, train_step)
            if valset:
                self.validation(model, val_loader, criterion, epoch, logger)
            self.save_model(model, "final", scheduler.get_lr(), optimizer)
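
`self.save_model(model, "final", scheduler.get_lr(), optimizer)` is called with a signature none of these examples defines. A hedged sketch of a matching helper, written as a plain function and reusing the checkpoint layout Example #9 saves and restores (the file-name scheme and the extra 'lr' field are assumptions):

import os
import torch

def save_model(model, epoch, lr, optimizer):
    save_dir = os.path.join(cfg.save_dir, cfg.exp_name)
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, 'textsnake_{}.pth'.format(epoch))
    print('Saving to {}.'.format(save_path))
    state_dict = {
        'lr': lr,                     # assumed extra field
        'epoch': epoch,
        'model': model.state_dict(),  # 'module.'-prefixed keys under DataParallel
        'optim': optimizer.state_dict(),
    }
    torch.save(state_dict, save_path)
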
Example #8
def main():

    global lr
    if cfg.exp_name == 'Totaltext':
        trainset = TotalText(data_root='data/total-text-mat',
                             ignore_list=None,
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
        # valset = TotalText(
        #     data_root='data/total-text-mat',
        #     ignore_list=None,
        #     is_training=False,
        #     transform=BaseTransform(size=cfg.input_size, mean=cfg.means, std=cfg.stds)
        # )
        valset = None

    elif cfg.exp_name == 'Synthtext':
        trainset = SynthText(data_root='data/SynthText',
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
        valset = None

    elif cfg.exp_name == 'Ctw1500':
        trainset = Ctw1500Text(data_root='data/ctw1500',
                               is_training=True,
                               transform=Augmentation(size=cfg.input_size,
                                                      mean=cfg.means,
                                                      std=cfg.stds))
        valset = None

    elif cfg.exp_name == 'Icdar2015':
        trainset = Icdar15Text(data_root='data/Icdar2015',
                               is_training=True,
                               transform=Augmentation(size=cfg.input_size,
                                                      mean=cfg.means,
                                                      std=cfg.stds))
        valset = None
    elif cfg.exp_name == 'MLT2017':
        trainset = Mlt2017Text(data_root='data/MLT2017',
                               is_training=True,
                               transform=Augmentation(size=cfg.input_size,
                                                      mean=cfg.means,
                                                      std=cfg.stds))
        valset = None

    elif cfg.exp_name == 'TD500':
        trainset = TD500Text(data_root='data/TD500',
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
        valset = None

    else:
        raise ValueError("dataset name is not correct")

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers,
                                   pin_memory=True)

    log_dir = os.path.join(
        cfg.log_dir,
        datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
    logger = LogSummary(log_dir)

    # Model
    model = TextNet(backbone=cfg.net, is_training=True)
    if cfg.mgpu:
        model = nn.DataParallel(model)

    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    if cfg.resume:
        load_model(model, cfg.resume)

    criterion = TextLoss()

    lr = cfg.lr
    moment = cfg.momentum
    if cfg.optim == "Adam" or cfg.exp_name == 'Synthtext':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=moment)

    if cfg.exp_name == 'Synthtext':
        scheduler = FixLR(optimizer)
    else:
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.90)

    print('Start training TextGraph.')
    for epoch in range(cfg.start_epoch, cfg.start_epoch + cfg.max_epoch + 1):
        scheduler.step()
        train(model, train_loader, criterion, scheduler, optimizer, epoch,
              logger)

    print('End.')

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
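
`FixLR`, used here and in Examples #4 and #6 for SynthText pre-training, is not defined in any of these snippets. Per the comment in Example #4 it simply holds the learning rate fixed, which the stock LambdaLR reproduces; a sketch:

from torch.optim import lr_scheduler

def FixLR(optimizer):
    # Constant multiplier of 1.0: stepping the scheduler never changes
    # the optimizer's initial lr, i.e. a fixed learning rate.
    return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 1.0)
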
Example #9
def train_TextSnake(config_file):

    import sys
    sys.path.append('./detection_model/TextSnake_pytorch')

    import os
    import time

    import torch
    import torch.backends.cudnn as cudnn
    import torch.utils.data as data
    from torch.optim import lr_scheduler
    import torchvision.utils as vutils
    from tensorboardX import SummaryWriter

    from dataset.total_text import TotalText
    from network.loss import TextLoss
    from network.textnet import TextNet
    from util.augmentation import EvalTransform, NewAugmentation
    from util.config import config as cfg, update_config, print_config, init_config
    from util.misc import AverageMeter
    from util.misc import mkdirs, to_device
    from util.option import BaseOptions
    from util.visualize import visualize_network_output

    from yacs.config import CfgNode as CN

    global total_iter
    total_iter = 0

    def read_config_file(config_file):
        """
        Read config information from a yaml file.
        """
        with open(config_file) as f:
            opt = CN.load_cfg(f)
        return opt

    opt = read_config_file(config_file)

    def adjust_learning_rate(optimizer, i):
        if 0 <= i * opt.batch_size < 100000:
            lr = opt.lr
        elif 100000 <= i * opt.batch_size < 400000:
            lr = opt.lr * 0.1
        else:
            lr = opt.lr * 0.1 * (0.94**(
                (i * opt.batch_size - 300000) // 100000))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def adjust_requires_grad(model, i):
        # Freeze the first conv block for the first 4000 iterations,
        # then unfreeze it.
        frozen_names = ('conv1.0.weight', 'conv1.0.bias',
                        'conv1.1.weight', 'conv1.1.bias')
        for name, param in model.named_parameters():
            if name in frozen_names:
                param.requires_grad = not (0 <= i < 4000)

    def save_model(model, optimizer, scheduler, epoch):
        save_dir = os.path.join(opt.save_dir, opt.exp_name)
        if not os.path.exists(save_dir):
            mkdirs(save_dir)

        save_path = os.path.join(
            save_dir, 'textsnake_{}_{}.pth'.format(model.backbone_name, epoch))
        print('Saving to {}.'.format(save_path))
        state_dict = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optim': optimizer.state_dict()
            # 'scheduler': scheduler.state_dict()
        }
        torch.save(state_dict, save_path)

    def load_model(save_path):
        print('Loading from {}.'.format(save_path))
        checkpoint = torch.load(save_path)
        return checkpoint

    def train(model, train_loader, criterion, scheduler, optimizer, epoch,
              summary_writer):

        start = time.time()
        losses = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        end = time.time()
        model.train()
        global total_iter

        for i, (img, train_mask, tr_mask, tcl_mask, radius_map, sin_map,
                cos_map, meta) in enumerate(train_loader):
            data_time.update(time.time() - end)

            img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map = to_device(
                img, train_mask, tr_mask, tcl_mask, radius_map, sin_map,
                cos_map)

            output = model(img)
            tr_loss, tcl_loss, sin_loss, cos_loss, radii_loss = \
                criterion(output, tr_mask, tcl_mask, sin_map, cos_map, radius_map, train_mask, total_iter)
            loss = tr_loss + tcl_loss + sin_loss + cos_loss + radii_loss

            # backward
            # scheduler.step()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.update(loss.item())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if opt.viz and i < opt.vis_num:
                visualize_network_output(output,
                                         tr_mask,
                                         tcl_mask,
                                         prefix='train_{}'.format(i))

            if i % opt.display_freq == 0:
                print(
                    'Epoch: [ {} ][ {:03d} / {:03d} ] - Loss: {:.4f} - tr_loss: {:.4f} - tcl_loss: {:.4f} - sin_loss: {:.4f} - cos_loss: {:.4f} - radii_loss: {:.4f} - {:.2f}s/step'
                    .format(epoch, i, len(train_loader), loss.item(),
                            tr_loss.item(), tcl_loss.item(), sin_loss.item(),
                            cos_loss.item(), radii_loss.item(),
                            batch_time.avg))

            # write summary
            if total_iter % opt.summary_freq == 0:
                print('Summary in {}'.format(
                    os.path.join(opt.summary_dir, opt.exp_name)))
                tr_pred = output[:, 0:2].softmax(dim=1)[:, 1:2]
                tcl_pred = output[:, 2:4].softmax(dim=1)[:, 1:2]
                summary_writer.add_image('input_image',
                                         vutils.make_grid(img, normalize=True),
                                         total_iter)
                summary_writer.add_image(
                    'tr/tr_pred',
                    vutils.make_grid(tr_pred * 255, normalize=True),
                    total_iter)
                summary_writer.add_image(
                    'tr/tr_mask',
                    vutils.make_grid(
                        torch.unsqueeze(tr_mask * train_mask, 1) * 255),
                    total_iter)
                summary_writer.add_image(
                    'tcl/tcl_pred',
                    vutils.make_grid(tcl_pred * 255, normalize=True),
                    total_iter)
                summary_writer.add_image(
                    'tcl/tcl_mask',
                    vutils.make_grid(
                        torch.unsqueeze(tcl_mask * train_mask, 1) * 255),
                    total_iter)
                summary_writer.add_scalar('learning_rate',
                                          optimizer.param_groups[0]['lr'],
                                          total_iter)
                summary_writer.add_scalar('model/tr_loss', tr_loss.item(),
                                          total_iter)
                summary_writer.add_scalar('model/tcl_loss', tcl_loss.item(),
                                          total_iter)
                summary_writer.add_scalar('model/sin_loss', sin_loss.item(),
                                          total_iter)
                summary_writer.add_scalar('model/cos_loss', cos_loss.item(),
                                          total_iter)
                summary_writer.add_scalar('model/radii_loss',
                                          radii_loss.item(), total_iter)
                summary_writer.add_scalar('model/loss', loss.item(),
                                          total_iter)

            total_iter += 1

        print('Speed: {}s /step, {}s /epoch'.format(batch_time.avg,
                                                    time.time() - start))

        if epoch % opt.save_freq == 0:
            save_model(model, optimizer, scheduler, epoch)

        print('Training Loss: {}'.format(losses.avg))

    def validation(model, valid_loader, criterion):
        """
        print a series of loss information
        """

        model.eval()
        losses = AverageMeter()

        for i, (img, train_mask, tr_mask, tcl_mask, radius_map, sin_map,
                cos_map, meta) in enumerate(valid_loader):
            print(meta['image_id'])
            img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map = to_device(
                img, train_mask, tr_mask, tcl_mask, radius_map, sin_map,
                cos_map)

            output = model(img)

            tr_loss, tcl_loss, sin_loss, cos_loss, radii_loss = \
                criterion(output, tr_mask, tcl_mask, sin_map, cos_map, radius_map, train_mask)
            loss = tr_loss + tcl_loss + sin_loss + cos_loss + radii_loss
            losses.update(loss.item())

            if opt.viz and i < opt.vis_num:
                visualize_network_output(output,
                                         tr_mask,
                                         tcl_mask,
                                         prefix='val_{}'.format(i))

            if i % opt.display_freq == 0:
                print(
                    'Validation: - Loss: {:.4f} - tr_loss: {:.4f} - tcl_loss: {:.4f} - sin_loss: {:.4f} - cos_loss: {:.4f} - radii_loss: {:.4f}'
                    .format(loss.item(), tr_loss.item(), tcl_loss.item(),
                            sin_loss.item(), cos_loss.item(),
                            radii_loss.item()))

        print('Validation Loss: {}'.format(losses.avg))

    # parse arguments

    torch.cuda.set_device(opt.num_device)
    option = BaseOptions(config_file)
    args = option.initialize()

    init_config(opt, config_file)
    update_config(opt, args)
    print_config(opt)

    data_root = os.path.join(opt.data_root, opt.dataset)

    trainset = TotalText(data_root=data_root,
                         ignore_list=os.path.join(data_root,
                                                  'ignore_list.txt'),
                         is_training=True,
                         transform=NewAugmentation(size=opt.input_size,
                                                   mean=opt.means,
                                                   std=opt.stds,
                                                   maxlen=1280,
                                                   minlen=512))

    train_loader = data.DataLoader(trainset,
                                   batch_size=opt.batch_size,
                                   shuffle=True,
                                   num_workers=opt.num_workers)

    # Model
    model = TextNet(backbone=opt.backbone, output_channel=7)
    model = model.to(opt.device)
    if opt.cuda:
        cudnn.benchmark = True

    criterion = TextLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)
    # if opt.dataset == 'ArT_train':
    #     scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[10000, 50000], gamma=0.1)
    # elif opt.dataset == 'LSVT_full_train':
    #     scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[10000, 50000], gamma=0.1)

    # load model if resume
    if opt.resume:  # resume is either False/None or a checkpoint path
        checkpoint = load_model(opt.resume)
        opt.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optim'])
        # scheduler.load_state_dict(checkpoint['scheduler'])
        total_iter = checkpoint['epoch'] * len(train_loader)

    if not os.path.exists(os.path.join(opt.summary_dir, opt.exp_name)):
        os.mkdir(os.path.join(opt.summary_dir, opt.exp_name))
    summary_writer = SummaryWriter(
        log_dir=os.path.join(opt.summary_dir, opt.exp_name))

    print('Start training TextSnake.')

    for epoch in range(opt.start_epoch, opt.max_epoch):
        adjust_learning_rate(optimizer, total_iter)
        train(model, train_loader, criterion, None, optimizer, epoch,
              summary_writer)

    print('End.')
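
Example #9 reads all of its settings from a yacs YAML file via `CN.load_cfg`. A minimal driver sketch, with a hypothetical config whose keys mirror the `opt.*` attributes the function reads (values are placeholders):

# config/textsnake.yaml (hypothetical; keys mirror the opt.* reads above):
#   exp_name: total_text_run
#   dataset: total_text
#   data_root: /path/to/datasets
#   num_device: 0
#   batch_size: 8
#   num_workers: 4
#   lr: 0.0001
#   max_epoch: 200
#   ...
train_TextSnake('config/textsnake.yaml')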