Esempio n. 1
0
def main():
    """Run inference with a UHT_Net model on the configured test dataset.

    Builds the test dataset/loader from ``cfg``, loads the evaluation
    checkpoint, and writes predictions under
    ``cfg.prediction_output_directory/<dataset>_<backbone>``.

    Raises:
        ValueError: if ``cfg.dataset_name`` is not a supported dataset.
    """
    if cfg.dataset_name == "total_text":
        testset = TotalText(
            data_root=cfg.dataset_root,
            ignore_list=None,
            is_training=False,
            transform=BaseTransform(size=cfg.img_size, mean=cfg.means, std=cfg.stds)
        )
    elif cfg.dataset_name == "coco_text":
        testset = COCO_Text(
            data_root=cfg.dataset_root,
            ignore_list=None,
            is_training=False,
            transform=BaseTransform(size=cfg.img_size, mean=cfg.means, std=cfg.stds)
        )
    else:
        # Fail fast: the original set testset=None, which crashed with an
        # opaque error inside DataLoader instead of naming the bad config.
        raise ValueError("Unsupported dataset_name: {}".format(cfg.dataset_name))

    test_loader = data.DataLoader(testset, batch_size=1, shuffle=False,
                                  num_workers=cfg.testing_num_workers)

    # Model
    model = UHT_Net()
    model_path = cfg.evaluation_model_directory
    load_model(model, model_path)

    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True  # fixed input size -> let cuDNN autotune kernels
    detector = TextFill(model)

    print('Start testing Model.')
    output_dir = os.path.join(cfg.prediction_output_directory, cfg.dataset_name + '_' + cfg.backbone)
    inference(detector, test_loader, output_dir)
Esempio n. 2
0
def main(vis_dir_path):
    """Run TextGraph inference on the dataset selected by ``cfg.exp_name``.

    Args:
        vis_dir_path: directory (re)created via ``osmkdir`` for visualisations.

    Raises:
        ValueError: if ``cfg.exp_name`` is unsupported, or ``cfg.graph_link``
            is disabled (no detector would be available).
    """
    osmkdir(vis_dir_path)
    if cfg.exp_name == "Totaltext":
        testset = TotalText(data_root='data/total-text-mat',
                            ignore_list=None,
                            is_training=False,
                            transform=BaseTransform(size=cfg.test_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))

    elif cfg.exp_name == "Ctw1500":
        testset = Ctw1500Text(data_root='data/ctw1500',
                              is_training=False,
                              transform=BaseTransform(size=cfg.test_size,
                                                      mean=cfg.means,
                                                      std=cfg.stds))
    elif cfg.exp_name == "TD500":
        testset = TD500Text(data_root='data/TD500',
                            is_training=False,
                            transform=BaseTransform(size=cfg.test_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
    else:
        # Fail fast: the original only printed a message and then raised a
        # NameError on the undefined ``testset`` below.
        raise ValueError("{} is not a supported exp_name".format(cfg.exp_name))

    test_loader = data.DataLoader(testset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=cfg.num_workers)

    # Model
    model = TextNet(is_training=False, backbone=cfg.net)
    model_path = os.path.join(cfg.save_dir, cfg.exp_name,
                              'TextGraph_{}.pth'.format(model.backbone_name))
    model.load_model(model_path)

    # copy to cuda
    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    if not cfg.graph_link:
        # Without graph_link the original left ``detector`` undefined and
        # crashed at the inference() call; make the requirement explicit.
        raise ValueError("cfg.graph_link must be enabled to build a detector")
    detector = TextDetector_graph(model)

    print('Start testing TextGraph.')
    output_dir = os.path.join(cfg.output_dir, cfg.exp_name)
    inference(detector, test_loader, output_dir)
Esempio n. 3
0
def main():
    """Run TextSnake inference over every image found under ``cfg.img_root``."""
    deploy_tf = BaseTransform(size=cfg.input_size, mean=cfg.means, std=cfg.stds)
    testset = DeployDataset(image_root=cfg.img_root, transform=deploy_tf)
    print(cfg)
    loader = data.DataLoader(testset, batch_size=1, shuffle=False,
                             num_workers=cfg.num_workers)

    # Restore the checkpoint matching the configured backbone and epoch.
    model = TextNet(is_training=False, backbone=cfg.net)
    checkpoint = 'textsnake_{}_{}.pth'.format(model.backbone_name, cfg.checkepoch)
    model.load_model(os.path.join(cfg.save_dir, cfg.exp_name, checkpoint))

    # Move to the configured device before wrapping in the detector.
    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True
    detector = TextDetector(model,
                            tr_thresh=cfg.tr_thresh,
                            tcl_thresh=cfg.tcl_thresh)

    print('Start testing TextSnake.')
    inference(detector, loader, os.path.join(cfg.output_dir, cfg.exp_name))
Esempio n. 4
0
def main():
    """Run TextSnake inference on the Total-Text test split."""
    # Deterministic preprocessing for evaluation.
    eval_tf = BaseTransform(size=cfg.input_size, mean=cfg.means, std=cfg.stds)
    testset = TotalText(
        data_root='/home/shf/fudan_ocr_system/datasets/totaltext',
        ignore_list=None,
        is_training=False,
        transform=eval_tf,
    )
    loader = data.DataLoader(testset, batch_size=1, shuffle=False,
                             num_workers=cfg.num_workers)

    # Restore the checkpoint matching the configured backbone and epoch.
    model = TextNet()
    checkpoint = 'textsnake_{}_{}.pth'.format(model.backbone_name,
                                              cfg.checkepoch)
    load_model(model, os.path.join(cfg.save_dir, cfg.exp_name, checkpoint))

    # copy to cuda
    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True
    detector = TextDetector()

    print('Start testing TextSnake.')

    inference(model, detector, loader)

    print('End.')
Esempio n. 5
0
def main():
    """Train TextSnake on Total-Text, validating after every epoch.

    Raises:
        ValueError: if ``cfg.dataset`` is not a supported dataset.
    """
    global lr

    if cfg.dataset == 'total-text':

        trainset = TotalText(data_root='data/total-text',
                             ignore_list=None,
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))

        valset = TotalText(data_root='data/total-text',
                           ignore_list=None,
                           is_training=False,
                           transform=BaseTransform(size=cfg.input_size,
                                                   mean=cfg.means,
                                                   std=cfg.stds))
    else:
        # Fail fast: 'pass' left trainset/valset undefined and produced an
        # opaque NameError at the DataLoader calls below.
        raise ValueError("Unsupported dataset: {}".format(cfg.dataset))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers)
    val_loader = data.DataLoader(valset,
                                 batch_size=cfg.batch_size,
                                 shuffle=False,
                                 num_workers=cfg.num_workers)

    # Timestamped log directory so repeated runs never collide.
    log_dir = os.path.join(
        cfg.log_dir,
        datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
    mkdirs(log_dir)
    logger = LogSummary(log_dir)

    # Model
    model = TextNet()
    if cfg.mgpu:
        model = nn.DataParallel(model, device_ids=cfg.gpu_ids)

    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    criterion = TextLoss()
    lr = cfg.lr  # publish the starting LR through the module-level global
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    print('Start training TextSnake.')

    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        train(model, train_loader, criterion, scheduler, optimizer, epoch,
              logger)
        # Validation never needs gradients; saves memory and time.
        with torch.no_grad():
            validation(model, val_loader, criterion, epoch, logger)

    print('End.')
Esempio n. 6
0
def main():
    """Evaluate TextSnake on Total-Text, then run the DetEval protocol."""
    testset = TotalText(
        data_root='data/total-text',
        ignore_list=None,
        is_training=False,
        transform=BaseTransform(size=cfg.input_size, mean=cfg.means, std=cfg.stds)
    )
    test_loader = data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=cfg.num_workers)

    # Model
    model = TextNet()
    model_path = os.path.join(cfg.save_dir, cfg.exp_name,
                              'textsnake_{}_{}.pth'.format(model.backbone_name, cfg.checkepoch))
    load_model(model, model_path)

    # copy to cuda
    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True
    detector = TextDetector(tr_thresh=cfg.tr_thresh, tcl_thresh=cfg.tcl_thresh)

    print('Start testing TextSnake.')
    output_dir = os.path.join(cfg.output_dir, cfg.exp_name)
    inference(model, detector, test_loader, output_dir)

    # compute DetEval
    print('Computing DetEval in {}/{}'.format(cfg.output_dir, cfg.exp_name))
    # BUG FIX: the original passed ``args.exp_name`` although every other
    # setting in this function is read from ``cfg``; use cfg.exp_name so the
    # evaluated experiment always matches the one just inferred.
    subprocess.call(['python', 'dataset/total_text/Evaluation_Protocol/Python_scripts/Deteval.py', cfg.exp_name])
    print('End.')
Esempio n. 7
0
def main():
    """Train TextSnake on Total-Text with per-epoch validation.

    Raises:
        ValueError: if ``cfg.dataset`` is not a supported dataset.
    """
    if cfg.dataset == 'total-text':

        trainset = TotalText(
            data_root='data/total-text',
            ignore_list='./dataset/total_text/ignore_list.txt',
            is_training=True,
            transform=Augmentation(size=cfg.input_size,
                                   mean=cfg.means,
                                   std=cfg.stds))

        valset = TotalText(data_root='data/total-text',
                           ignore_list=None,
                           is_training=False,
                           transform=BaseTransform(size=cfg.input_size,
                                                   mean=cfg.means,
                                                   std=cfg.stds))
    else:
        # Fail fast: 'pass' left trainset/valset undefined and produced an
        # opaque NameError at the DataLoader calls below.
        raise ValueError("Unsupported dataset: {}".format(cfg.dataset))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers)
    val_loader = data.DataLoader(valset,
                                 batch_size=cfg.batch_size,
                                 shuffle=False,
                                 num_workers=cfg.num_workers)

    # Model
    # BUG FIX: the model was constructed only when cfg.mgpu was set, leaving
    # ``model`` undefined on single-GPU runs. Build it unconditionally and
    # wrap in DataParallel only for multi-GPU training.
    model = TextNet()
    if cfg.mgpu:
        model = nn.DataParallel(model, device_ids=cfg.gpu_ids)

    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    criterion = TextLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    print('Start training TextSnake.')

    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        train(model, train_loader, criterion, scheduler, optimizer, epoch)
        validation(model, val_loader, criterion)

    print('End.')
Esempio n. 8
0
def main():
    """Train TextSnake on Total-Text or SynthText.

    SynthText is used for pre-training (fixed LR, no validation split);
    Total-Text uses step LR decay and per-epoch validation.

    Raises:
        ValueError: if ``cfg.dataset`` is not a supported dataset.
    """
    global lr

    if cfg.dataset == 'total-text':

        trainset = TotalText(data_root='data/total-text',
                             ignore_list=None,
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))

        valset = TotalText(data_root='data/total-text',
                           ignore_list=None,
                           is_training=False,
                           transform=BaseTransform(size=cfg.input_size,
                                                   mean=cfg.means,
                                                   std=cfg.stds))

    elif cfg.dataset == 'synth-text':
        trainset = SynthText(data_root='data/SynthText',
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
        valset = None
    else:
        # Fail fast: 'pass' left trainset undefined and produced an opaque
        # NameError at the DataLoader call below.
        raise ValueError("Unsupported dataset: {}".format(cfg.dataset))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers)
    # NOTE: passing pin_memory=True to DataLoader would allocate tensors in
    # page-locked host memory, speeding up host-to-GPU transfers.
    if valset:
        val_loader = data.DataLoader(valset,
                                     batch_size=cfg.batch_size,
                                     shuffle=False,
                                     num_workers=cfg.num_workers)
    else:
        # BUG FIX: the original re-assigned ``valset = None`` here (a no-op,
        # valset is already falsy) and left ``val_loader`` undefined.
        val_loader = None

    # Timestamped log directory so repeated runs never collide.
    log_dir = os.path.join(
        cfg.log_dir,
        datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
    logger = LogSummary(log_dir)

    # Model: backbone defaults to vgg via cfg.net.
    model = TextNet(is_training=True, backbone=cfg.net)
    if cfg.mgpu:
        model = nn.DataParallel(model)

    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True
    # Resume training from a checkpoint if requested.
    if cfg.resume:
        load_model(model, cfg.resume)

    # Loss combines loss_tr, loss_tcl, loss_radii, loss_sin, loss_cos.
    criterion = TextLoss()
    lr = cfg.lr
    # Adam optimizer.
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

    if cfg.dataset == 'synth-text':
        # Fixed learning rate (last_epoch=-1 initializes lr to cfg.lr).
        scheduler = FixLR(optimizer)
    else:
        # Step decay: multiply lr by 0.1 every step_size steps.
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    print('Start training TextSnake.')
    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        train(model, train_loader, criterion, scheduler, optimizer, epoch,
              logger)
        if valset:
            validation(model, val_loader, criterion, epoch, logger)

    print('End.')
Esempio n. 9
0
    os.mkdir(cfg.vis_dir)

# Total-Text datasets with plain-text annotations: training uses random
# augmentation, validation uses the deterministic base transform.
trainset = TotalText_txt("data/total-text/Images/Train/",
                         "gt/Train/",
                         ignore_list=None,
                         is_training=True,
                         transform=Augmentation(size=cfg.input_size,
                                                mean=cfg.means,
                                                std=cfg.stds))

valset = TotalText_txt("data/total-text/Images/Test/",
                       "gt/Test/",
                       ignore_list=None,
                       is_training=False,
                       transform=BaseTransform(size=cfg.input_size,
                                               mean=cfg.means,
                                               std=cfg.stds))

train_loader = data.DataLoader(trainset,
                               batch_size=cfg.batch_size,
                               shuffle=True,
                               num_workers=cfg.num_workers)

if valset:
    val_loader = data.DataLoader(valset,
                                 batch_size=cfg.batch_size,
                                 shuffle=False,
                                 num_workers=cfg.num_workers)
else:
    # BUG FIX: the original assigned ``valset = None`` here (a no-op, since
    # valset is already falsy) and left ``val_loader`` undefined.
    val_loader = None
Esempio n. 10
0
def main():
    """Train TextSnake on a custom dataset, logging validation loss."""
    trainset = CustomText(data_root='data/{}'.format(cfg.dataset),
                          is_training=True,
                          transform=Augmentation(size=cfg.input_size,
                                                 mean=cfg.means,
                                                 std=cfg.stds))

    valset = CustomText(data_root='data/{}'.format(cfg.dataset),
                        is_training=False,
                        transform=BaseTransform(size=cfg.input_size,
                                                mean=cfg.means,
                                                std=cfg.stds))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers,
                                   collate_fn=my_collate)
    if valset:
        val_loader = data.DataLoader(valset,
                                     batch_size=cfg.batch_size,
                                     shuffle=False,
                                     num_workers=cfg.num_workers,
                                     collate_fn=my_collate)
    else:
        # BUG FIX: the original re-assigned ``valset = None`` here (a no-op,
        # valset is already falsy) and left ``val_loader`` undefined.
        val_loader = None

    # Timestamped log directory so repeated runs never collide.
    log_dir = os.path.join(
        cfg.log_dir,
        datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
    logger = LogSummary(log_dir)

    # Model
    model = TextNet(is_training=True, backbone=cfg.net)
    if cfg.mgpu:
        model = nn.DataParallel(model)

    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    if cfg.resume:
        load_model(model, cfg.resume)

    criterion = TextLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

    if cfg.dataset == 'synth-text':
        scheduler = FixLR(optimizer)  # constant LR for synthetic pre-training
    else:
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    print('Start training TextSnake.')
    writer = SummaryWriter(logdir='process')
    for epoch in range(cfg.start_epoch, cfg.max_epoch):

        train(model, train_loader, criterion, scheduler, optimizer, epoch,
              logger)
        if valset:
            validation_loss = validation(model, val_loader, criterion, epoch,
                                         logger)

            writer.add_scalar('Validation Loss', validation_loss, epoch)

    print('End.')
Esempio n. 11
0
        ignore_list=
        None,  #'/workspace/mnt/group/ocr/qiutairu/dataset/ArT_train/ignore_list.txt',
        is_training=True,
        transform=transform)

    for idx in range(len(trainset)):
        img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map, meta = trainset[
            idx]
        if img.shape[0] != 3:
            print(idx, img.shape)

    testset = TotalText(
        data_root='/home/shf/fudan_ocr_system/datasets/ICDAR19/',
        ignore_list=None,
        is_training=False,
        transform=BaseTransform(size=512, mean=means, std=stds))

    for idx in range(len(testset)):
        img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map, meta = testset[
            idx]
        if img.shape[0] != 3:
            print(idx, img.shape)

    # path = '/workspace/mnt/group/ocr/qiutairu/dataset/ArT_train/train_images'
    # files = os.listdir(path)
    #
    # for file in files:
    #     image = pil_load_img(os.path.join(path, file))
    #     if image.shape[2] != 3:
    #         print(file, image.shape)
Esempio n. 12
0
    def Train(self):
        """Train a TextSnake text detector from the parameters collected in
        ``self.system_dict``.

        Copies every user parameter into the global ``cfg``, selects the
        device, creates weight/visualisation directories, builds the
        train/val datasets (text or mat annotation format), then runs the
        epoch loop, saving the model after each epoch.
        """
        # --- Mirror user-supplied parameters into the global config object ---
        cfg.max_epoch = self.system_dict["params"]["max_epoch"]
        cfg.means = self.system_dict["params"]["means"]
        cfg.stds = self.system_dict["params"]["stds"]
        cfg.log_dir = self.system_dict["params"]["log_dir"]
        cfg.exp_name = self.system_dict["params"]["exp_name"]
        cfg.net = self.system_dict["params"]["net"]
        cfg.resume = self.system_dict["params"]["resume"]
        cfg.cuda = self.system_dict["params"]["cuda"]
        cfg.mgpu = self.system_dict["params"]["mgpu"]
        cfg.save_dir = self.system_dict["params"]["save_dir"]
        cfg.vis_dir = self.system_dict["params"]["vis_dir"]
        cfg.input_channel = self.system_dict["params"]["input_channel"]

        cfg.lr = self.system_dict["params"]["lr"]
        cfg.weight_decay = self.system_dict["params"]["weight_decay"]
        cfg.gamma = self.system_dict["params"]["gamma"]
        cfg.momentum = self.system_dict["params"]["momentum"]
        cfg.optim = self.system_dict["params"]["optim"]
        cfg.display_freq = self.system_dict["params"]["display_freq"]
        cfg.viz_freq = self.system_dict["params"]["viz_freq"]
        cfg.save_freq = self.system_dict["params"]["save_freq"]
        cfg.log_freq = self.system_dict["params"]["log_freq"]

        cfg.batch_size = self.system_dict["params"]["batch_size"]
        cfg.rescale = self.system_dict["params"]["rescale"]
        cfg.checkepoch = self.system_dict["params"]["checkepoch"]

        # --- Fixed (non-user-configurable) settings ---
        cfg.val_freq = 1000
        cfg.start_iter = 0
        cfg.loss = "CrossEntropyLoss"
        cfg.pretrain = False
        cfg.verbose = True
        cfg.viz = True
        cfg.lr_adjust = "step"  #fix, step
        cfg.stepvalues = []
        cfg.step_size = cfg.max_epoch // 2  # one LR decay at half the run

        # Device selection; the default tensor type decides where newly
        # created tensors live.
        if cfg.cuda and torch.cuda.is_available():
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
            cudnn.benchmark = True
            cfg.device = torch.device("cuda")
        else:
            torch.set_default_tensor_type('torch.FloatTensor')
            cfg.device = torch.device("cpu")

        # Create weights saving directory
        if not os.path.exists(cfg.save_dir):
            os.mkdir(cfg.save_dir)

        # Create weights saving directory of target model
        model_save_path = os.path.join(cfg.save_dir, cfg.exp_name)

        if not os.path.exists(model_save_path):
            os.mkdir(model_save_path)

        if not os.path.exists(cfg.vis_dir):
            os.mkdir(cfg.vis_dir)

        # Datasets: annotation_type chooses between .txt and .mat ground
        # truth formats.
        # NOTE(review): if annotation_type is neither "text" nor "mat",
        # train_loader/valset stay undefined and the loop below raises
        # NameError — confirm upstream validation guarantees one of the two.
        if (self.system_dict["params"]["annotation_type"] == "text"):
            trainset = TotalText_txt(
                self.system_dict["params"]["train_img_folder"],
                self.system_dict["params"]["train_anno_folder"],
                ignore_list=None,
                is_training=True,
                transform=Augmentation(size=cfg.input_size,
                                       mean=cfg.means,
                                       std=cfg.stds))
            train_loader = data.DataLoader(trainset,
                                           batch_size=cfg.batch_size,
                                           shuffle=True,
                                           num_workers=cfg.num_workers)

            if (self.system_dict["params"]["val_dataset"]):
                valset = TotalText_txt(
                    self.system_dict["params"]["val_img_folder"],
                    self.system_dict["params"]["val_anno_folder"],
                    ignore_list=None,
                    is_training=False,
                    transform=BaseTransform(size=cfg.input_size,
                                            mean=cfg.means,
                                            std=cfg.stds))
                val_loader = data.DataLoader(valset,
                                             batch_size=cfg.batch_size,
                                             shuffle=False,
                                             num_workers=cfg.num_workers)
            else:
                valset = None

        elif (self.system_dict["params"]["annotation_type"] == "mat"):
            trainset = TotalText_mat(
                self.system_dict["params"]["train_img_folder"],
                self.system_dict["params"]["train_anno_folder"],
                ignore_list=None,
                is_training=True,
                transform=Augmentation(size=cfg.input_size,
                                       mean=cfg.means,
                                       std=cfg.stds))
            train_loader = data.DataLoader(trainset,
                                           batch_size=cfg.batch_size,
                                           shuffle=True,
                                           num_workers=cfg.num_workers)

            if (self.system_dict["params"]["val_dataset"]):
                valset = TotalText_mat(
                    self.system_dict["params"]["val_img_folder"],
                    self.system_dict["params"]["val_anno_folder"],
                    ignore_list=None,
                    is_training=False,
                    transform=BaseTransform(size=cfg.input_size,
                                            mean=cfg.means,
                                            std=cfg.stds))
                val_loader = data.DataLoader(valset,
                                             batch_size=cfg.batch_size,
                                             shuffle=False,
                                             num_workers=cfg.num_workers)
            else:
                valset = None

        # Timestamped log directory so repeated runs never collide.
        log_dir = os.path.join(
            cfg.log_dir,
            datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
        logger = LogSummary(log_dir)

        model = TextNet(is_training=True, backbone=cfg.net)
        if cfg.mgpu:
            model = nn.DataParallel(model)

        model = model.to(cfg.device)

        if cfg.cuda:
            cudnn.benchmark = True

        # Resume from a checkpoint if one was provided.
        if cfg.resume:
            self.load_model(model, cfg.resume)

        criterion = TextLoss()
        lr = cfg.lr
        # NOTE(review): both branches create Adam; the else branch was
        # presumably meant to honour cfg.optim with a different optimizer
        # (cfg.momentum/weight_decay are collected but unused) — confirm.
        if (cfg.optim == "adam"):
            optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
        else:
            optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=cfg.step_size,
                                        gamma=cfg.gamma)

        train_step = 0
        print('Start training TextSnake.')
        for epoch in range(cfg.start_epoch, cfg.max_epoch):
            train_step = self.train(model, train_loader, criterion, scheduler,
                                    optimizer, epoch, logger, train_step)
            if valset:
                self.validation(model, val_loader, criterion, epoch, logger)
            # Save after every epoch, always tagged "final".
            self.save_model(model, "final", scheduler.get_lr(), optimizer)
Esempio n. 13
0
            if polygon.text != '#':
                sideline1, sideline2, center_points, radius = polygon.disk_cover(
                    n_disk=cfg.n_disk)
                self.make_text_center_line(sideline1, sideline2, center_points,
                                           radius, tcl_mask, radius_map,
                                           sin_map, cos_map)
        tr_mask, train_mask = self.make_text_region(image, polygons)
        # to pytorch channel sequence
        image = image.transpose(2, 0, 1)

        meta = {
            'image_id': image_id,
            'image_path': image_path,
            'Height': H,
            'Width': W
        }
        return image, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map, meta

    def __len__(self):
        """Return the number of samples (one per entry in image_list)."""
        return len(self.image_list)


if __name__ == '__main__':
    import os
    from util.augmentation import BaseTransform

    # Smoke test: build the training dataset with a simple normalising
    # transform to verify the annotation files parse correctly.
    transform = BaseTransform(size=512, mean=0.5, std=0.5)
    trainset = TotalText(data_root='data/total-text',
                         ignore_list='./ignore_list.txt',
                         is_training=True,
                         transform=transform)
Esempio n. 14
0
def main():
    """Train UHT_Net on the dataset selected by ``cfg.dataset_name``.

    total_text trains with a validation split; coco_text and synth_text
    train without one.

    Raises:
        ValueError: if ``cfg.dataset_name`` is not a supported dataset.
    """
    if cfg.dataset_name == 'total_text':

        trainset = TotalText(data_root=cfg.dataset_root,
                             ignore_list=None,
                             is_training=True,
                             transform=Augmentation(size=cfg.img_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))

        valset = TotalText(data_root=cfg.dataset_root,
                           ignore_list=None,
                           is_training=False,
                           transform=BaseTransform(size=cfg.img_size,
                                                   mean=cfg.means,
                                                   std=cfg.stds))

    elif cfg.dataset_name == 'coco_text':
        trainset = COCO_Text(data_root=cfg.dataset_root,
                             is_training=True,
                             transform=Augmentation(size=cfg.img_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
        valset = None

    elif cfg.dataset_name == 'synth_text':
        trainset = SynthText(
            data_root='/home/andrew/Documents/Dataset/SynthText/SynthText',
            is_training=True,
            transform=Augmentation(size=cfg.img_size,
                                   mean=cfg.means,
                                   std=cfg.stds))
        valset = None
    else:
        # Fail fast: the original fell through with trainset=None, which
        # crashed obscurely inside DataLoader.
        raise ValueError("Unsupported dataset_name: {}".format(cfg.dataset_name))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.training_num_workers,
                                   pin_memory=True,
                                   timeout=10)
    if valset:
        val_loader = data.DataLoader(valset,
                                     batch_size=cfg.batch_size,
                                     shuffle=False,
                                     num_workers=cfg.training_num_workers)
    else:
        val_loader = None

    model = UHT_Net(pretrained=True)
    if cfg.multi_gpu:
        model = nn.DataParallel(model)

    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    # Resume from a checkpoint if one was provided.
    if cfg.resume != "":
        load_model(model, cfg.resume)
    criterion = UHT_Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)

    # Decay once every ``decay_epoch`` epochs (scheduler is stepped per
    # iteration, hence the multiplication by the loader length).
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=len(train_loader) *
                                    cfg.decay_epoch,
                                    gamma=cfg.decay_rate)

    print('Start training Model.')
    for epoch in range(cfg.start_epoch, cfg.end_epoch):
        train(model, train_loader, criterion, scheduler, optimizer, epoch)
        if valset:
            validation(model, val_loader, criterion)

    print('End.')
Esempio n. 15
0
    def Predict(self,
                image_path,
                output_img_path="output.jpg",
                output_txt_path="output.txt",
                tr_thresh=0.4,
                tcl_thresh=0.4):
        """Detect text in one image; save the visualisation and contours.

        Args:
            image_path: path of the input image.
            output_img_path: where the rendered detection overlay is saved.
            output_txt_path: where the contour coordinates are written.
            tr_thresh: text-region confidence threshold.
            tcl_thresh: text-center-line confidence threshold.
        """
        cfg = self.system_dict["local"]["cfg"]
        model = self.system_dict["local"]["model"]

        start = time.time()
        image = pil_load_img(image_path)

        transform = BaseTransform(size=cfg.input_size,
                                  mean=cfg.means,
                                  std=cfg.stds)

        H, W, _ = image.shape

        image, polygons = transform(image)

        # to pytorch channel sequence: (H, W, C) -> (C, H, W)
        image = image.transpose(2, 0, 1)

        meta = {
            'image_id': 0,
            'image_path': image_path,
            'Height': H,
            'Width': W
        }
        # Add the batch dimension expected by the detector.
        image = torch.from_numpy(np.expand_dims(image, axis=0))
        image = to_device(image)
        if cfg.cuda:
            torch.cuda.synchronize()

        end = time.time()
        print("Image loading time: {}".format(end - start))

        start = time.time()
        detector = TextDetector(model,
                                tr_thresh=tr_thresh,
                                tcl_thresh=tcl_thresh)
        # get detection result
        contours, output = detector.detect(image)

        # BUG FIX: synchronize only when CUDA is in use; the original called
        # torch.cuda.synchronize() unconditionally here, crashing CPU-only runs.
        if cfg.cuda:
            torch.cuda.synchronize()
        end = time.time()

        print("Inference time - {}".format(end - start))

        start = time.time()
        tr_pred, tcl_pred = output['tr'], output['tcl']
        img_show = image[0].permute(1, 2, 0).cpu().numpy()
        # Undo normalisation for visualisation.
        img_show = ((img_show * cfg.stds + cfg.means) * 255).astype(np.uint8)

        img_show, contours = rescale_result(img_show, contours, H, W)

        pred_vis = visualize_detection(img_show, contours)
        cv2.imwrite(output_img_path, pred_vis)

        # write to file
        self.write_to_file(contours, output_txt_path)
        end = time.time()

        print("Writing output time - {}".format(end - start))