Exemplo n.º 1
0
def main():
    """Train TextSnake on Total-Text.

    Builds the train/val datasets and loaders, constructs the model, loss,
    optimizer and LR scheduler, then runs the training loop with a
    validation pass each epoch.  All settings come from the module-level
    ``cfg`` object.

    Raises:
        ValueError: if ``cfg.dataset`` is not a supported dataset name.
    """

    global lr

    if cfg.dataset == 'total-text':

        trainset = TotalText(data_root='data/total-text',
                             ignore_list=None,
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))

        valset = TotalText(data_root='data/total-text',
                           ignore_list=None,
                           is_training=False,
                           transform=BaseTransform(size=cfg.input_size,
                                                   mean=cfg.means,
                                                   std=cfg.stds))
    else:
        # Fail fast with a clear error instead of falling through and hitting
        # a confusing NameError on the undefined `trainset` below.
        raise ValueError('Unsupported dataset: {!r}'.format(cfg.dataset))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers)
    val_loader = data.DataLoader(valset,
                                 batch_size=cfg.batch_size,
                                 shuffle=False,
                                 num_workers=cfg.num_workers)

    # Per-run log directory, timestamped so repeated runs do not collide.
    log_dir = os.path.join(
        cfg.log_dir,
        datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
    mkdirs(log_dir)
    logger = LogSummary(log_dir)

    # Model
    model = TextNet()
    if cfg.mgpu:
        model = nn.DataParallel(model, device_ids=cfg.gpu_ids)

    model = model.to(cfg.device)
    if cfg.cuda:
        # Let cuDNN benchmark conv algorithms for fixed-size inputs.
        cudnn.benchmark = True

    criterion = TextLoss()
    lr = cfg.lr
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    # Multiply the LR by 0.94 every 10000 scheduler steps.
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    print('Start training TextSnake.')

    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        train(model, train_loader, criterion, scheduler, optimizer, epoch,
              logger)
        # No gradients are needed during validation; saves memory.
        with torch.no_grad():
            validation(model, val_loader, criterion, epoch, logger)

    print('End.')
Exemplo n.º 2
0
def main():
    """Train TextSnake on Total-Text or SynthText.

    Builds the dataset/loader pair selected by ``cfg.dataset``, constructs
    the model, loss, optimizer and LR scheduler, then runs the training
    loop, validating each epoch when a validation set exists.

    Raises:
        ValueError: if ``cfg.dataset`` is not a supported dataset name.
    """

    global lr

    if cfg.dataset == 'total-text':

        trainset = TotalText(data_root='data/total-text',
                             ignore_list=None,
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))

        valset = TotalText(data_root='data/total-text',
                           ignore_list=None,
                           is_training=False,
                           transform=BaseTransform(size=cfg.input_size,
                                                   mean=cfg.means,
                                                   std=cfg.stds))

    elif cfg.dataset == 'synth-text':
        trainset = SynthText(data_root='data/SynthText',
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
        # No validation split for SynthText in this script.
        valset = None
    else:
        # Fail fast instead of falling through and hitting a NameError on
        # the undefined `trainset` below.
        raise ValueError('Unsupported dataset: {!r}'.format(cfg.dataset))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers)
    # (translated note) With pin_memory=True a DataLoader would keep the
    # produced tensors in page-locked host memory, making host-to-GPU copies
    # faster; it is not enabled here.
    val_loader = None
    if valset:
        val_loader = data.DataLoader(valset,
                                     batch_size=cfg.batch_size,
                                     shuffle=False,
                                     num_workers=cfg.num_workers)

    log_dir = os.path.join(
        cfg.log_dir,
        datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
    logger = LogSummary(log_dir)

    # Model: backbone defaults to VGG.
    model = TextNet(is_training=True, backbone=cfg.net)
    if cfg.mgpu:
        model = nn.DataParallel(model)

    model = model.to(cfg.device)
    if cfg.cuda:
        # Let cuDNN benchmark conv algorithms for fixed-size inputs.
        cudnn.benchmark = True
    # Resume from a checkpoint when one was given.
    if cfg.resume:
        load_model(model, cfg.resume)

    # Loss combines loss_tr, loss_tcl, loss_radii, loss_sin and loss_cos.
    criterion = TextLoss()
    lr = cfg.lr
    # Adam optimizer.
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

    if cfg.dataset == 'synth-text':
        # Fixed learning rate (with last_epoch = -1 the initial lr is used).
        scheduler = FixLR(optimizer)
    else:
        # Multiply the LR by 0.1 every 100 scheduler steps.
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    print('Start training TextSnake.')
    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        train(model, train_loader, criterion, scheduler, optimizer, epoch,
              logger)
        if valset:
            # No gradients are needed during validation; saves memory.
            with torch.no_grad():
                validation(model, val_loader, criterion, epoch, logger)

    print('End.')
Exemplo n.º 3
0
                               batch_size=cfg.batch_size,
                               shuffle=True,
                               num_workers=cfg.num_workers)

# NOTE(review): this snippet is a truncated excerpt — the lines above are the
# tail of a data.DataLoader(trainset, ...) call whose opening line (and the
# enclosing function, if any) is missing from this view.

# Build a validation loader only when a validation set exists.
if valset:
    val_loader = data.DataLoader(valset,
                                 batch_size=cfg.batch_size,
                                 shuffle=False,
                                 num_workers=cfg.num_workers)
else:
    # NOTE(review): valset is already falsy here, so this assignment looks
    # redundant — presumably `val_loader = None` was intended; confirm.
    valset = None

# Per-run log directory, timestamped so repeated runs do not collide.
log_dir = os.path.join(
    cfg.log_dir,
    datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
logger = LogSummary(log_dir)

# Model: backbone selected by cfg.net.
model = TextNet(is_training=True, backbone=cfg.net)
if cfg.mgpu:
    model = nn.DataParallel(model)

model = model.to(cfg.device)

if cfg.cuda:
    # Let cuDNN benchmark conv algorithms for fixed-size inputs.
    cudnn.benchmark = True

# Resume from a checkpoint when one was given.
if cfg.resume:
    load_model(model, cfg.resume)

criterion = TextLoss()
lr = cfg.lr
Exemplo n.º 4
0
def main():
    """Train TextSnake on a CustomText dataset named by ``cfg.dataset``.

    Builds the train/val datasets and loaders (using a custom collate
    function), constructs the model, loss, optimizer and LR scheduler, then
    trains, logging the validation loss to TensorBoard each epoch.
    """

    trainset = CustomText(data_root='data/{}'.format(cfg.dataset),
                          is_training=True,
                          transform=Augmentation(size=cfg.input_size,
                                                 mean=cfg.means,
                                                 std=cfg.stds))

    valset = CustomText(data_root='data/{}'.format(cfg.dataset),
                        is_training=False,
                        transform=BaseTransform(size=cfg.input_size,
                                                mean=cfg.means,
                                                std=cfg.stds))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers,
                                   collate_fn=my_collate)
    # Only build a validation loader when the validation set is non-empty
    # (a Dataset with __len__() == 0 is falsy).
    val_loader = None
    if valset:
        val_loader = data.DataLoader(valset,
                                     batch_size=cfg.batch_size,
                                     shuffle=False,
                                     num_workers=cfg.num_workers,
                                     collate_fn=my_collate)

    # Per-run log directory, timestamped so repeated runs do not collide.
    log_dir = os.path.join(
        cfg.log_dir,
        datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
    logger = LogSummary(log_dir)

    # Model
    model = TextNet(is_training=True, backbone=cfg.net)
    if cfg.mgpu:
        model = nn.DataParallel(model)

    model = model.to(cfg.device)
    if cfg.cuda:
        # Let cuDNN benchmark conv algorithms for fixed-size inputs.
        cudnn.benchmark = True

    # Resume from a checkpoint when one was given.
    if cfg.resume:
        load_model(model, cfg.resume)

    criterion = TextLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

    if cfg.dataset == 'synth-text':
        # Fixed learning rate for synth-text.
        scheduler = FixLR(optimizer)
    else:
        # Multiply the LR by 0.1 every 100 scheduler steps.
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    print('Start training TextSnake.')
    writer = SummaryWriter(logdir='process')
    try:
        for epoch in range(cfg.start_epoch, cfg.max_epoch):

            train(model, train_loader, criterion, scheduler, optimizer, epoch,
                  logger)
            if valset:
                # No gradients are needed during validation; saves memory.
                with torch.no_grad():
                    validation_loss = validation(model, val_loader, criterion,
                                                 epoch, logger)

                writer.add_scalar('Validation Loss', validation_loss, epoch)
    finally:
        # Flush and close the TensorBoard event file even on error/interrupt.
        writer.close()

    print('End.')
Exemplo n.º 5
0
    def Train(self):
        """Train a TextSnake model using the settings stored in
        ``self.system_dict["params"]``.

        Copies the user-supplied parameters onto the global ``cfg`` object,
        prepares the output directories, builds the dataset/loader pair
        (txt- or mat-style Total-Text annotations), constructs the model,
        loss, optimizer and LR scheduler, then runs the epoch loop,
        validating (when a validation set was supplied) and saving the
        model each epoch.
        """
        # --- Copy user-supplied parameters onto the global config. ---
        cfg.max_epoch = self.system_dict["params"]["max_epoch"]
        cfg.means = self.system_dict["params"]["means"]
        cfg.stds = self.system_dict["params"]["stds"]
        cfg.log_dir = self.system_dict["params"]["log_dir"]
        cfg.exp_name = self.system_dict["params"]["exp_name"]
        cfg.net = self.system_dict["params"]["net"]
        cfg.resume = self.system_dict["params"]["resume"]
        cfg.cuda = self.system_dict["params"]["cuda"]
        cfg.mgpu = self.system_dict["params"]["mgpu"]
        cfg.save_dir = self.system_dict["params"]["save_dir"]
        cfg.vis_dir = self.system_dict["params"]["vis_dir"]
        cfg.input_channel = self.system_dict["params"]["input_channel"]

        cfg.lr = self.system_dict["params"]["lr"]
        cfg.weight_decay = self.system_dict["params"]["weight_decay"]
        cfg.gamma = self.system_dict["params"]["gamma"]
        cfg.momentum = self.system_dict["params"]["momentum"]
        cfg.optim = self.system_dict["params"]["optim"]
        cfg.display_freq = self.system_dict["params"]["display_freq"]
        cfg.viz_freq = self.system_dict["params"]["viz_freq"]
        cfg.save_freq = self.system_dict["params"]["save_freq"]
        cfg.log_freq = self.system_dict["params"]["log_freq"]

        cfg.batch_size = self.system_dict["params"]["batch_size"]
        cfg.rescale = self.system_dict["params"]["rescale"]
        cfg.checkepoch = self.system_dict["params"]["checkepoch"]

        # --- Fixed (non-user-configurable) training settings. ---
        cfg.val_freq = 1000
        cfg.start_iter = 0
        cfg.loss = "CrossEntropyLoss"
        cfg.pretrain = False
        cfg.verbose = True
        cfg.viz = True
        cfg.lr_adjust = "step"  # "fix" or "step"
        cfg.stepvalues = []
        cfg.step_size = cfg.max_epoch // 2

        # Select the compute device; cuDNN benchmarking speeds up convs with
        # fixed-size inputs.
        if cfg.cuda and torch.cuda.is_available():
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
            cudnn.benchmark = True
            cfg.device = torch.device("cuda")
        else:
            torch.set_default_tensor_type('torch.FloatTensor')
            cfg.device = torch.device("cpu")

        # Create weights saving directory
        if not os.path.exists(cfg.save_dir):
            os.mkdir(cfg.save_dir)

        # Create weights saving directory of target model
        model_save_path = os.path.join(cfg.save_dir, cfg.exp_name)

        if not os.path.exists(model_save_path):
            os.mkdir(model_save_path)

        if not os.path.exists(cfg.vis_dir):
            os.mkdir(cfg.vis_dir)

        # NOTE(review): if annotation_type is neither "text" nor "mat",
        # neither trainset nor valset is defined and the code below raises
        # NameError — confirm the value is validated upstream.
        if (self.system_dict["params"]["annotation_type"] == "text"):
            # Dataset built from plain-text annotation files.
            trainset = TotalText_txt(
                self.system_dict["params"]["train_img_folder"],
                self.system_dict["params"]["train_anno_folder"],
                ignore_list=None,
                is_training=True,
                transform=Augmentation(size=cfg.input_size,
                                       mean=cfg.means,
                                       std=cfg.stds))
            train_loader = data.DataLoader(trainset,
                                           batch_size=cfg.batch_size,
                                           shuffle=True,
                                           num_workers=cfg.num_workers)

            # Optional validation set.
            if (self.system_dict["params"]["val_dataset"]):
                valset = TotalText_txt(
                    self.system_dict["params"]["val_img_folder"],
                    self.system_dict["params"]["val_anno_folder"],
                    ignore_list=None,
                    is_training=False,
                    transform=BaseTransform(size=cfg.input_size,
                                            mean=cfg.means,
                                            std=cfg.stds))
                val_loader = data.DataLoader(valset,
                                             batch_size=cfg.batch_size,
                                             shuffle=False,
                                             num_workers=cfg.num_workers)
            else:
                valset = None

        elif (self.system_dict["params"]["annotation_type"] == "mat"):
            # Dataset built from MATLAB .mat annotation files.
            trainset = TotalText_mat(
                self.system_dict["params"]["train_img_folder"],
                self.system_dict["params"]["train_anno_folder"],
                ignore_list=None,
                is_training=True,
                transform=Augmentation(size=cfg.input_size,
                                       mean=cfg.means,
                                       std=cfg.stds))
            train_loader = data.DataLoader(trainset,
                                           batch_size=cfg.batch_size,
                                           shuffle=True,
                                           num_workers=cfg.num_workers)

            # Optional validation set.
            if (self.system_dict["params"]["val_dataset"]):
                valset = TotalText_mat(
                    self.system_dict["params"]["val_img_folder"],
                    self.system_dict["params"]["val_anno_folder"],
                    ignore_list=None,
                    is_training=False,
                    transform=BaseTransform(size=cfg.input_size,
                                            mean=cfg.means,
                                            std=cfg.stds))
                val_loader = data.DataLoader(valset,
                                             batch_size=cfg.batch_size,
                                             shuffle=False,
                                             num_workers=cfg.num_workers)
            else:
                valset = None

        # Per-run log directory, timestamped so repeated runs do not collide.
        log_dir = os.path.join(
            cfg.log_dir,
            datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
        logger = LogSummary(log_dir)

        model = TextNet(is_training=True, backbone=cfg.net)
        if cfg.mgpu:
            model = nn.DataParallel(model)

        model = model.to(cfg.device)

        if cfg.cuda:
            cudnn.benchmark = True

        # Resume from a checkpoint when one was given.
        if cfg.resume:
            self.load_model(model, cfg.resume)

        criterion = TextLoss()
        lr = cfg.lr
        # NOTE(review): both branches construct Adam — the else branch was
        # presumably meant to build a different optimizer (e.g. SGD with
        # cfg.momentum / cfg.weight_decay); confirm intended behavior.
        if (cfg.optim == "adam"):
            optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
        else:
            optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

        # Step decay: multiply the LR by cfg.gamma every cfg.step_size epochs.
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=cfg.step_size,
                                        gamma=cfg.gamma)

        train_step = 0
        print('Start training TextSnake.')
        for epoch in range(cfg.start_epoch, cfg.max_epoch):
            train_step = self.train(model, train_loader, criterion, scheduler,
                                    optimizer, epoch, logger, train_step)
            if valset:
                self.validation(model, val_loader, criterion, epoch, logger)
            # NOTE(review): saves under the name "final" every epoch,
            # presumably overwriting the same checkpoint each time; also
            # scheduler.get_lr() is deprecated in newer PyTorch in favor of
            # get_last_lr() — confirm the targeted PyTorch version.
            self.save_model(model, "final", scheduler.get_lr(), optimizer)
Exemplo n.º 6
0
def main():
    """Train TextGraph on the dataset selected by ``cfg.exp_name``.

    Builds the training dataset and loader, constructs the model, loss,
    optimizer and LR scheduler, then runs the training loop.  None of the
    supported datasets uses a validation split in this script.

    Raises:
        ValueError: if ``cfg.exp_name`` is not a supported dataset name.
    """

    global lr
    if cfg.exp_name == 'Totaltext':
        trainset = TotalText(data_root='data/total-text-mat',
                             ignore_list=None,
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
        valset = None

    elif cfg.exp_name == 'Synthtext':
        trainset = SynthText(data_root='data/SynthText',
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
        valset = None

    elif cfg.exp_name == 'Ctw1500':
        trainset = Ctw1500Text(data_root='data/ctw1500',
                               is_training=True,
                               transform=Augmentation(size=cfg.input_size,
                                                      mean=cfg.means,
                                                      std=cfg.stds))
        valset = None

    elif cfg.exp_name == 'Icdar2015':
        trainset = Icdar15Text(data_root='data/Icdar2015',
                               is_training=True,
                               transform=Augmentation(size=cfg.input_size,
                                                      mean=cfg.means,
                                                      std=cfg.stds))
        valset = None
    elif cfg.exp_name == 'MLT2017':
        trainset = Mlt2017Text(data_root='data/MLT2017',
                               is_training=True,
                               transform=Augmentation(size=cfg.input_size,
                                                      mean=cfg.means,
                                                      std=cfg.stds))
        valset = None

    elif cfg.exp_name == 'TD500':
        trainset = TD500Text(data_root='data/TD500',
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
        valset = None

    else:
        # Fail fast: falling through would raise a NameError on the
        # undefined `trainset` below.
        raise ValueError(
            'dataset name is not correct: {!r}'.format(cfg.exp_name))

    # pin_memory keeps batches in page-locked host memory for faster
    # host-to-GPU transfers.
    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers,
                                   pin_memory=True)

    # Per-run log directory, timestamped so repeated runs do not collide.
    log_dir = os.path.join(
        cfg.log_dir,
        datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
    logger = LogSummary(log_dir)

    # Model
    model = TextNet(backbone=cfg.net, is_training=True)
    if cfg.mgpu:
        model = nn.DataParallel(model)

    model = model.to(cfg.device)
    if cfg.cuda:
        # Let cuDNN benchmark conv algorithms for fixed-size inputs.
        cudnn.benchmark = True

    # Resume from a checkpoint when one was given.
    if cfg.resume:
        load_model(model, cfg.resume)

    criterion = TextLoss()

    lr = cfg.lr
    moment = cfg.momentum
    # Synthtext pre-training always uses Adam; otherwise cfg.optim decides.
    if cfg.optim == "Adam" or cfg.exp_name == 'Synthtext':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=moment)

    if cfg.exp_name == 'Synthtext':
        # Fixed learning rate for Synthtext.
        scheduler = FixLR(optimizer)
    else:
        # Multiply the LR by 0.90 every 100 scheduler steps.
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.90)

    print('Start training TextGraph.')
    for epoch in range(cfg.start_epoch, cfg.start_epoch + cfg.max_epoch + 1):
        # NOTE(review): stepping the scheduler before train() follows the
        # pre-1.1 PyTorch calling order; in PyTorch >= 1.1 scheduler.step()
        # should come after optimizer.step() — confirm the target version.
        scheduler.step()
        train(model, train_loader, criterion, scheduler, optimizer, epoch,
              logger)

    print('End.')

    # Release cached GPU memory held by the allocator.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()