Example #1
0
def main():
    """Run TotalText inference with a trained TextSnake model, then compute DetEval."""

    # Evaluation split: deterministic resize/normalize only (no augmentation).
    testset = TotalText(
        data_root='data/total-text',
        ignore_list=None,
        is_training=False,
        transform=BaseTransform(size=cfg.input_size, mean=cfg.means, std=cfg.stds)
    )
    # batch_size must stay 1: test images keep their native sizes.
    test_loader = data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=cfg.num_workers)

    # Model
    model = TextNet()
    # Parentheses already continue the expression; the old trailing backslash
    # was redundant.
    model_path = os.path.join(
        cfg.save_dir, cfg.exp_name,
        'textsnake_{}_{}.pth'.format(model.backbone_name, cfg.checkepoch))
    load_model(model, model_path)

    # copy to cuda
    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True
    detector = TextDetector(tr_thresh=cfg.tr_thresh, tcl_thresh=cfg.tcl_thresh)

    print('Start testing TextSnake.')
    output_dir = os.path.join(cfg.output_dir, cfg.exp_name)
    inference(model, detector, test_loader, output_dir)

    # compute DetEval
    print('Computing DetEval in {}/{}'.format(cfg.output_dir, cfg.exp_name))
    # NOTE(review): args.exp_name is used here while everything else in this
    # function reads cfg.exp_name — confirm the two always agree.
    subprocess.call(['python', 'dataset/total_text/Evaluation_Protocol/Python_scripts/Deteval.py', args.exp_name])
    print('End.')
Example #2
0
def main():
    """Run TextSnake inference over the TotalText test split."""

    # Test-time dataset: resize/normalize transform only.
    eval_set = TotalText(
        data_root='/home/shf/fudan_ocr_system/datasets/totaltext',
        ignore_list=None,
        is_training=False,
        transform=BaseTransform(size=cfg.input_size, mean=cfg.means, std=cfg.stds)
    )
    # One image per batch: sizes vary at test time.
    eval_loader = data.DataLoader(
        eval_set, batch_size=1, shuffle=False, num_workers=cfg.num_workers)

    # Build the network and restore the requested checkpoint epoch.
    model = TextNet()
    checkpoint_name = 'textsnake_{}_{}.pth'.format(model.backbone_name, cfg.checkepoch)
    model_path = os.path.join(cfg.save_dir, cfg.exp_name, checkpoint_name)
    load_model(model, model_path)

    # Move to the configured device; enable cudnn autotuning on GPU.
    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True
    detector = TextDetector()

    print('Start testing TextSnake.')

    inference(model, detector, eval_loader)

    print('End.')
Example #3
0
def main():
    """Train TextSnake on TotalText and run validation after every epoch."""

    global lr

    if cfg.dataset == 'total-text':

        # Training split: heavy augmentation.
        trainset = TotalText(data_root='data/total-text',
                             ignore_list=None,
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))

        # Validation split: deterministic resize/normalize only.
        valset = TotalText(data_root='data/total-text',
                           ignore_list=None,
                           is_training=False,
                           transform=BaseTransform(size=cfg.input_size,
                                                   mean=cfg.means,
                                                   std=cfg.stds))
    else:
        # The old `pass` left trainset/valset unbound and produced a confusing
        # NameError below; fail fast with a clear message instead.
        raise ValueError('Unsupported dataset: {}'.format(cfg.dataset))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers)
    val_loader = data.DataLoader(valset,
                                 batch_size=cfg.batch_size,
                                 shuffle=False,
                                 num_workers=cfg.num_workers)

    # Timestamped log directory so repeated runs never collide.
    log_dir = os.path.join(
        cfg.log_dir,
        datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
    mkdirs(log_dir)
    logger = LogSummary(log_dir)

    # Model
    model = TextNet()
    if cfg.mgpu:
        model = nn.DataParallel(model, device_ids=cfg.gpu_ids)

    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    criterion = TextLoss()
    lr = cfg.lr
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    print('Start training TextSnake.')

    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        train(model, train_loader, criterion, scheduler, optimizer, epoch,
              logger)
        # Validation never needs gradients; no_grad saves memory.
        with torch.no_grad():
            validation(model, val_loader, criterion, epoch, logger)

    print('End.')
Example #4
0
def main():
    """Run TextSnake detection over a deployment image folder."""

    # Deployment dataset: raw images plus resize/normalize transform.
    testset = DeployDataset(image_root=cfg.img_root,
                            transform=BaseTransform(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
    print(cfg)
    # batch_size=1: deployment images may have different native sizes.
    test_loader = data.DataLoader(testset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=cfg.num_workers)

    # Model
    model = TextNet(is_training=False, backbone=cfg.net)
    # Parentheses already continue the expression; the old trailing backslash
    # was redundant.
    model_path = os.path.join(
        cfg.save_dir, cfg.exp_name,
        'textsnake_{}_{}.pth'.format(model.backbone_name, cfg.checkepoch))
    model.load_model(model_path)

    # copy to cuda
    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True
    detector = TextDetector(model,
                            tr_thresh=cfg.tr_thresh,
                            tcl_thresh=cfg.tcl_thresh)

    print('Start testing TextSnake.')
    output_dir = os.path.join(cfg.output_dir, cfg.exp_name)
    inference(detector, test_loader, output_dir)
Example #5
0
    def Model_Params(self, model_type="vgg", model_path=None, use_gpu=True):
        """Record model settings in the config, load the checkpoint, and
        move the network to the selected device.

        Args:
            model_type: backbone identifier (e.g. "vgg").
            model_path: filesystem path of the checkpoint to load.
            use_gpu: run on CUDA when True, otherwise on CPU.
        """
        # All state lives in one nested dict; alias it once for readability.
        local = self.system_dict["local"]
        local["net"] = model_type
        local["model_path"] = model_path
        local["cuda"] = use_gpu

        local["cfg"] = cfg

        # Mirror the stored settings onto the shared cfg object.
        local["cfg"].net = local["net"]
        local["cfg"].cuda = local["cuda"]
        local["cfg"].means = local["means"]
        local["cfg"].stds = local["stds"]
        local["cfg"].input_size = local["input_size"]

        model = TextNet(is_training=False, backbone=local["cfg"].net)
        model.load_model(model_path)

        # copy to cuda
        if local["cfg"].cuda:
            cudnn.benchmark = True
            local["cfg"].device = torch.device("cuda")
        else:
            local["cfg"].device = torch.device("cpu")

        local["model"] = model.to(local["cfg"].device)
Example #6
0
def main():
    """Train TextSnake on TotalText with per-epoch validation."""

    if cfg.dataset == 'total-text':

        trainset = TotalText(
            data_root='data/total-text',
            ignore_list='./dataset/total_text/ignore_list.txt',
            is_training=True,
            transform=Augmentation(size=cfg.input_size,
                                   mean=cfg.means,
                                   std=cfg.stds))

        valset = TotalText(data_root='data/total-text',
                           ignore_list=None,
                           is_training=False,
                           transform=BaseTransform(size=cfg.input_size,
                                                   mean=cfg.means,
                                                   std=cfg.stds))
    else:
        # The old `pass` left trainset/valset unbound and crashed later with a
        # confusing NameError; fail fast instead.
        raise ValueError('Unsupported dataset: {}'.format(cfg.dataset))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers)
    val_loader = data.DataLoader(valset,
                                 batch_size=cfg.batch_size,
                                 shuffle=False,
                                 num_workers=cfg.num_workers)

    # Model
    # BUG FIX: the model was previously constructed only inside `if cfg.mgpu:`,
    # which raised NameError on single-GPU runs. Build it unconditionally and
    # wrap with DataParallel only when multi-GPU is requested (same pattern as
    # the other training scripts in this file).
    model = TextNet()
    if cfg.mgpu:
        model = nn.DataParallel(model, device_ids=cfg.gpu_ids)

    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    criterion = TextLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)

    print('Start training TextSnake.')

    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        train(model, train_loader, criterion, scheduler, optimizer, epoch)
        validation(model, val_loader, criterion)

    print('End.')
Example #7
0
def main(vis_dir_path):
    """Run TextGraph inference on the dataset selected by cfg.exp_name.

    Args:
        vis_dir_path: directory (re)created for visualization output.
    """

    osmkdir(vis_dir_path)
    # Pick the test set matching the experiment name.
    if cfg.exp_name == "Totaltext":
        testset = TotalText(data_root='data/total-text-mat',
                            ignore_list=None,
                            is_training=False,
                            transform=BaseTransform(size=cfg.test_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))

    elif cfg.exp_name == "Ctw1500":
        testset = Ctw1500Text(data_root='data/ctw1500',
                              is_training=False,
                              transform=BaseTransform(size=cfg.test_size,
                                                      mean=cfg.means,
                                                      std=cfg.stds))
    elif cfg.exp_name == "TD500":
        testset = TD500Text(data_root='data/TD500',
                            is_training=False,
                            transform=BaseTransform(size=cfg.test_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
    else:
        # The old branch only printed a message and then crashed on the
        # unbound `testset`; raise immediately with a clear error.
        raise ValueError('Unknown experiment name: {}'.format(cfg.exp_name))

    test_loader = data.DataLoader(testset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=cfg.num_workers)

    # Model
    model = TextNet(is_training=False, backbone=cfg.net)
    model_path = os.path.join(cfg.save_dir, cfg.exp_name,
                              'TextGraph_{}.pth'.format(model.backbone_name))
    model.load_model(model_path)

    # copy to cuda
    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    if cfg.graph_link:
        detector = TextDetector_graph(model)
    else:
        # `detector` was previously unbound when graph_link was disabled,
        # crashing below with NameError; make the requirement explicit.
        raise ValueError('cfg.graph_link must be enabled: no other detector is configured')

    print('Start testing TextGraph.')
    output_dir = os.path.join(cfg.output_dir, cfg.exp_name)
    inference(detector, test_loader, output_dir)
Example #8
0
def main():
    """Run TextSnake inference on the ICDAR19 evaluation set."""
    testset = TotalText(
        data_root='/home/shf/fudan_ocr_system/datasets/ICDAR19/',
        ignore_list=
        '/home/shf/fudan_ocr_system/datasets/ICDAR19/ignore_list.txt',
        is_training=False,
        transform=EvalTransform(size=1280, mean=cfg.means, std=cfg.stds)
        # transform=BaseTransform(size=1280, mean=cfg.means, std=cfg.stds)
    )
    # batch_size=1: evaluation images keep their native sizes.
    test_loader = data.DataLoader(testset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=cfg.num_workers)

    # Model
    model = TextNet(backbone=cfg.backbone, output_channel=7)
    # Parentheses already continue the expression; the old trailing backslash
    # was redundant.
    model_path = os.path.join(
        cfg.save_dir, cfg.exp_name,
        'textsnake_{}_{}.pth'.format(model.backbone_name, cfg.checkepoch))
    load_model(model, model_path)

    # copy to cuda
    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True
    detector = TextDetector(tcl_conf_thresh=0.3, tr_conf_thresh=1.0)  # 0.3

    # Ensure vis_dir and output_dir exist. makedirs(exist_ok=True) also
    # creates missing parents and is safe against concurrent creation,
    # unlike the old exists()+mkdir pair.
    if cfg.viz:
        os.makedirs(cfg.vis_dir, exist_ok=True)
    os.makedirs(cfg.output_dir, exist_ok=True)

    print('Start testing TextSnake.')

    inference(model, detector, test_loader)

    print('End.')
Example #9
0
def main():
    """Train TextSnake with a 7-channel head, optionally resuming a checkpoint."""
    global total_iter
    data_root = os.path.join('/home/shf/fudan_ocr_system/datasets/',
                             cfg.dataset)

    trainset = TotalText(data_root=data_root,
                         ignore_list=os.path.join(data_root,
                                                  'ignore_list.txt'),
                         is_training=True,
                         transform=NewAugmentation(size=cfg.input_size,
                                                   mean=cfg.means,
                                                   std=cfg.stds,
                                                   maxlen=1280,
                                                   minlen=512))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers)

    # Model
    model = TextNet(backbone=cfg.backbone, output_channel=7)
    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    criterion = TextLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)
    # if cfg.dataset == 'ArT_train':
    #     scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[10000, 50000], gamma=0.1)
    # elif cfg.dataset == 'LSVT_full_train':
    #     scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[10000, 50000], gamma=0.1)

    # load model if resume
    if cfg.resume is not None:
        # (The redundant second `global total_iter` was removed: the
        # declaration at the top of this function already covers the scope.)
        checkpoint = load_model(cfg.resume)
        cfg.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optim'])
        # scheduler.load_state_dict(checkpoint['scheduler'])
        # Restore the global step counter to match the resumed epoch.
        total_iter = checkpoint['epoch'] * len(train_loader)

    if not os.path.exists(os.path.join(cfg.summary_dir, cfg.exp_name)):
        os.mkdir(os.path.join(cfg.summary_dir, cfg.exp_name))
    summary_writer = SummaryWriter(
        log_dir=os.path.join(cfg.summary_dir, cfg.exp_name))

    print('Start training TextSnake.')

    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        # Manual LR schedule driven by the global iteration counter.
        adjust_learning_rate(optimizer, total_iter)
        train(model, train_loader, criterion, None, optimizer, epoch,
              summary_writer)

    print('End.')
Example #10
0
def main():
    """Train TextGraph on the dataset selected by cfg.exp_name."""

    global lr
    # Every branch builds an augmented training set; validation is disabled
    # (valset = None) for all datasets.
    if cfg.exp_name == 'Totaltext':
        trainset = TotalText(data_root='data/total-text-mat',
                             ignore_list=None,
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
        # valset = TotalText(
        #     data_root='data/total-text-mat',
        #     ignore_list=None,
        #     is_training=False,
        #     transform=BaseTransform(size=cfg.input_size, mean=cfg.means, std=cfg.stds)
        # )
        valset = None

    elif cfg.exp_name == 'Synthtext':
        trainset = SynthText(data_root='data/SynthText',
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
        valset = None

    elif cfg.exp_name == 'Ctw1500':
        trainset = Ctw1500Text(data_root='data/ctw1500',
                               is_training=True,
                               transform=Augmentation(size=cfg.input_size,
                                                      mean=cfg.means,
                                                      std=cfg.stds))
        valset = None

    elif cfg.exp_name == 'Icdar2015':
        trainset = Icdar15Text(data_root='data/Icdar2015',
                               is_training=True,
                               transform=Augmentation(size=cfg.input_size,
                                                      mean=cfg.means,
                                                      std=cfg.stds))
        valset = None
    elif cfg.exp_name == 'MLT2017':
        trainset = Mlt2017Text(data_root='data/MLT2017',
                               is_training=True,
                               transform=Augmentation(size=cfg.input_size,
                                                      mean=cfg.means,
                                                      std=cfg.stds))
        valset = None

    elif cfg.exp_name == 'TD500':
        trainset = TD500Text(data_root='data/TD500',
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
        valset = None

    else:
        # The old branch only printed a message, then crashed on the unbound
        # `trainset` below; raise a clear error instead.
        raise ValueError('dataset name is not correct: {}'.format(cfg.exp_name))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers,
                                   pin_memory=True)

    # Timestamped log directory so repeated runs never collide.
    log_dir = os.path.join(
        cfg.log_dir,
        datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
    logger = LogSummary(log_dir)

    # Model
    model = TextNet(backbone=cfg.net, is_training=True)
    if cfg.mgpu:
        model = nn.DataParallel(model)

    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    if cfg.resume:
        load_model(model, cfg.resume)

    criterion = TextLoss()

    lr = cfg.lr
    moment = cfg.momentum
    # SynthText pretraining always uses Adam with a fixed LR.
    if cfg.optim == "Adam" or cfg.exp_name == 'Synthtext':
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=moment)

    if cfg.exp_name == 'Synthtext':
        scheduler = FixLR(optimizer)
    else:
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.90)

    print('Start training TextGraph.')
    for epoch in range(cfg.start_epoch, cfg.start_epoch + cfg.max_epoch + 1):
        # NOTE(review): scheduler.step() before the epoch is the legacy
        # (pre-PyTorch-1.1) ordering; moving it would change the LR schedule,
        # so it is kept as-is.
        scheduler.step()
        train(model, train_loader, criterion, scheduler, optimizer, epoch,
              logger)

    print('End.')

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
Example #11
0
def train_TextSnake(config_file):
    """Train TextSnake end-to-end from a YAML config file.

    Sets up the dataset, model, optimizer, and TensorBoard writer, then runs
    the epoch loop. All helpers (train/validation/save/load, LR adjustment)
    are defined as closures over the parsed config `opt` and the global
    iteration counter `total_iter`.
    """

    import sys
    sys.path.append('./detection_model/TextSnake_pytorch')

    import os
    import time

    import torch
    import torch.backends.cudnn as cudnn
    import torch.utils.data as data
    from torch.optim import lr_scheduler
    import torchvision.utils as vutils
    from tensorboardX import SummaryWriter

    from dataset.total_text import TotalText
    from network.loss import TextLoss
    from network.textnet import TextNet
    from util.augmentation import EvalTransform, NewAugmentation
    from util.config import config as cfg, update_config, print_config, init_config
    from util.misc import AverageMeter
    from util.misc import mkdirs, to_device
    from util.option import BaseOptions
    from util.visualize import visualize_network_output

    from yacs.config import CfgNode as CN

    # Global step counter shared by the nested train() and the resume logic.
    global total_iter
    total_iter = 0

    def read_config_file(config_file):
        """
        read config information form yaml file
        """
        # NOTE(review): the file handle is never closed — consider a `with`.
        f = open(config_file)
        opt = CN.load_cfg(f)
        return opt

    opt = read_config_file(config_file)

    def adjust_learning_rate(optimizer, i):
        # Step-wise LR decay keyed on the number of images seen (i * batch).
        if 0 <= i * opt.batch_size < 100000:
            lr = opt.lr
        elif 100000 <= i * opt.batch_size < 400000:
            lr = opt.lr * 0.1
        else:
            lr = opt.lr * 0.1 * (0.94**(
                (i * opt.batch_size - 300000) // 100000))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def adjust_requires_grad(model, i):
        # Freeze the first conv block for the first 4000 iterations,
        # then unfreeze it.
        if 0 <= i < 4000:
            for name, param in model.named_parameters():
                if name == 'conv1.0.weight' or name == 'conv1.0.bias' or \
                   name == 'conv1.1.weight' or name == 'conv1.1.bias':
                    param.requires_grad = False
        else:
            for name, param in model.named_parameters():
                if name == 'conv1.0.weight' or name == 'conv1.0.bias' or \
                   name == 'conv1.1.weight' or name == 'conv1.1.bias':
                    param.requires_grad = True

    def save_model(model, optimizer, scheduler, epoch):
        """Save model/optimizer state as textsnake_<backbone>_<epoch>.pth."""
        save_dir = os.path.join(opt.save_dir, opt.exp_name)
        if not os.path.exists(save_dir):
            mkdirs(save_dir)

        save_path = os.path.join(
            save_dir, 'textsnake_{}_{}.pth'.format(model.backbone_name, epoch))
        print('Saving to {}.'.format(save_path))
        state_dict = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optim': optimizer.state_dict()
            # 'scheduler': scheduler.state_dict()
        }
        torch.save(state_dict, save_path)

    def load_model(save_path):
        """Load and return a checkpoint dict saved by save_model()."""
        print('Loading from {}.'.format(save_path))
        checkpoint = torch.load(save_path)
        return checkpoint

    def train(model, train_loader, criterion, scheduler, optimizer, epoch,
              summary_writer):
        """Run one training epoch, logging losses and images to TensorBoard."""

        start = time.time()
        losses = AverageMeter()
        batch_time = AverageMeter()
        data_time = AverageMeter()
        end = time.time()
        model.train()
        global total_iter

        for i, (img, train_mask, tr_mask, tcl_mask, radius_map, sin_map,
                cos_map, meta) in enumerate(train_loader):
            data_time.update(time.time() - end)

            img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map = to_device(
                img, train_mask, tr_mask, tcl_mask, radius_map, sin_map,
                cos_map)

            # Forward pass and multi-task loss (text region, centerline,
            # geometry maps).
            output = model(img)
            tr_loss, tcl_loss, sin_loss, cos_loss, radii_loss = \
                criterion(output, tr_mask, tcl_mask, sin_map, cos_map, radius_map, train_mask, total_iter)
            loss = tr_loss + tcl_loss + sin_loss + cos_loss + radii_loss

            # backward
            # scheduler.step()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            losses.update(loss.item())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if opt.viz and i < opt.vis_num:
                visualize_network_output(output,
                                         tr_mask,
                                         tcl_mask,
                                         prefix='train_{}'.format(i))

            if i % opt.display_freq == 0:
                print(
                    'Epoch: [ {} ][ {:03d} / {:03d} ] - Loss: {:.4f} - tr_loss: {:.4f} - tcl_loss: {:.4f} - sin_loss: {:.4f} - cos_loss: {:.4f} - radii_loss: {:.4f} - {:.2f}s/step'
                    .format(epoch, i, len(train_loader), loss.item(),
                            tr_loss.item(), tcl_loss.item(), sin_loss.item(),
                            cos_loss.item(), radii_loss.item(),
                            batch_time.avg))

            # write summary
            if total_iter % opt.summary_freq == 0:
                print('Summary in {}'.format(
                    os.path.join(opt.summary_dir, opt.exp_name)))
                # Channels 0-1 are the text-region logits, 2-3 the
                # centerline logits; take the positive-class probability.
                tr_pred = output[:, 0:2].softmax(dim=1)[:, 1:2]
                tcl_pred = output[:, 2:4].softmax(dim=1)[:, 1:2]
                summary_writer.add_image('input_image',
                                         vutils.make_grid(img, normalize=True),
                                         total_iter)
                summary_writer.add_image(
                    'tr/tr_pred',
                    vutils.make_grid(tr_pred * 255, normalize=True),
                    total_iter)
                summary_writer.add_image(
                    'tr/tr_mask',
                    vutils.make_grid(
                        torch.unsqueeze(tr_mask * train_mask, 1) * 255),
                    total_iter)
                summary_writer.add_image(
                    'tcl/tcl_pred',
                    vutils.make_grid(tcl_pred * 255, normalize=True),
                    total_iter)
                summary_writer.add_image(
                    'tcl/tcl_mask',
                    vutils.make_grid(
                        torch.unsqueeze(tcl_mask * train_mask, 1) * 255),
                    total_iter)
                summary_writer.add_scalar('learning_rate',
                                          optimizer.param_groups[0]['lr'],
                                          total_iter)
                summary_writer.add_scalar('model/tr_loss', tr_loss.item(),
                                          total_iter)
                summary_writer.add_scalar('model/tcl_loss', tcl_loss.item(),
                                          total_iter)
                summary_writer.add_scalar('model/sin_loss', sin_loss.item(),
                                          total_iter)
                summary_writer.add_scalar('model/cos_loss', cos_loss.item(),
                                          total_iter)
                summary_writer.add_scalar('model/radii_loss',
                                          radii_loss.item(), total_iter)
                summary_writer.add_scalar('model/loss', loss.item(),
                                          total_iter)

            total_iter += 1

        print('Speed: {}s /step, {}s /epoch'.format(batch_time.avg,
                                                    time.time() - start))

        if epoch % opt.save_freq == 0:
            save_model(model, optimizer, scheduler, epoch)

        print('Training Loss: {}'.format(losses.avg))

    def validation(model, valid_loader, criterion):
        """
        print a series of loss information
        """

        model.eval()
        losses = AverageMeter()

        for i, (img, train_mask, tr_mask, tcl_mask, radius_map, sin_map,
                cos_map, meta) in enumerate(valid_loader):
            print(meta['image_id'])
            img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map = to_device(
                img, train_mask, tr_mask, tcl_mask, radius_map, sin_map,
                cos_map)

            output = model(img)

            tr_loss, tcl_loss, sin_loss, cos_loss, radii_loss = \
                criterion(output, tr_mask, tcl_mask, sin_map, cos_map, radius_map, train_mask)
            loss = tr_loss + tcl_loss + sin_loss + cos_loss + radii_loss
            losses.update(loss.item())

            if opt.viz and i < opt.vis_num:
                visualize_network_output(output,
                                         tr_mask,
                                         tcl_mask,
                                         prefix='val_{}'.format(i))

            if i % opt.display_freq == 0:
                print(
                    'Validation: - Loss: {:.4f} - tr_loss: {:.4f} - tcl_loss: {:.4f} - sin_loss: {:.4f} - cos_loss: {:.4f} - radii_loss: {:.4f}'
                    .format(loss.item(), tr_loss.item(), tcl_loss.item(),
                            sin_loss.item(), cos_loss.item(),
                            radii_loss.item()))

        print('Validation Loss: {}'.format(losses.avg))

    # parse arguments

    torch.cuda.set_device(opt.num_device)
    option = BaseOptions(config_file)
    args = option.initialize()

    init_config(opt, config_file)
    update_config(opt, args)
    print_config(opt)

    data_root = os.path.join(opt.data_root, opt.dataset)

    trainset = TotalText(data_root=data_root,
                         ignore_list=os.path.join(data_root,
                                                  'ignore_list.txt'),
                         is_training=True,
                         transform=NewAugmentation(size=opt.input_size,
                                                   mean=opt.means,
                                                   std=opt.stds,
                                                   maxlen=1280,
                                                   minlen=512))

    train_loader = data.DataLoader(trainset,
                                   batch_size=opt.batch_size,
                                   shuffle=True,
                                   num_workers=opt.num_workers)

    # Model
    model = TextNet(backbone=opt.backbone, output_channel=7)
    model = model.to(opt.device)
    if opt.cuda:
        cudnn.benchmark = True

    criterion = TextLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    # scheduler = lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.94)
    # if opt.dataset == 'ArT_train':
    #     scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[10000, 50000], gamma=0.1)
    # elif opt.dataset == 'LSVT_full_train':
    #     scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[10000, 50000], gamma=0.1)

    # load model if resume
    # NOTE(review): `is not False` treats any non-False value (including
    # None or "") as a resume path — confirm the config guarantees this.
    if opt.resume is not False:
        checkpoint = load_model(opt.resume)
        opt.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optim'])
        # scheduler.load_state_dict(checkpoint['scheduler'])
        total_iter = checkpoint['epoch'] * len(train_loader)

    if not os.path.exists(os.path.join(opt.summary_dir, opt.exp_name)):
        os.mkdir(os.path.join(opt.summary_dir, opt.exp_name))
    summary_writer = SummaryWriter(
        log_dir=os.path.join(opt.summary_dir, opt.exp_name))

    print('Start training TextSnake.')

    for epoch in range(opt.start_epoch, opt.max_epoch):
        # Manual LR schedule driven by the global iteration counter.
        adjust_learning_rate(optimizer, total_iter)
        train(model, train_loader, criterion, None, optimizer, epoch,
              summary_writer)

    print('End.')
Example #12
0
def test_TextSnake(config_file):
    """Run TextSnake inference over a test set and evaluate with DetEval.

    Reads run options from the yacs YAML *config_file*, builds the test
    dataloader and model, performs (optionally multi-scale) inference,
    writes every predicted polygon to ``<output_dir>/result.json`` and
    finally runs the DetEval protocol via ``detval()``.
    """
    import sys
    sys.path.append('./detection_model/TextSnake_pytorch')

    import os
    import time
    import numpy as np
    import torch
    import json

    import torch.backends.cudnn as cudnn
    import torch.utils.data as data
    import torch.nn.functional as func

    from dataset.total_text import TotalText
    from network.textnet import TextNet
    from util.detection import TextDetector
    from util.augmentation import BaseTransform, EvalTransform
    from util.config import config as cfg, update_config, print_config
    from util.misc import to_device, fill_hole
    from util.option import BaseOptions
    from util.visualize import visualize_detection
    from Evaluation.Detval import detval

    import cv2

    from yacs.config import CfgNode as CN

    def read_config_file(config_file):
        """Read run options from a yaml file into a yacs CfgNode."""
        # use a context manager so the config file handle is closed
        # (the original left it open)
        with open(config_file) as f:
            opt = CN.load_cfg(f)
        return opt

    opt = read_config_file(config_file)

    def result2polygon(image, result):
        """
        convert geometric info(center_x, center_y, radii) into contours
        :param result: (list), each with (n, 3), 3 denotes (x, y, radii)
        :return: (np.ndarray list), polygon format contours
        """
        # NOTE(review): the inner `for x, y, r in disk` implies each disk is
        # itself an iterable of (x, y, r) rows -- confirm against
        # detector.complete_detect's actual output nesting.
        conts = []
        for instance in result:
            # rasterize every predicted disk, then trace the union's outline
            mask = np.zeros(image.shape[:2], dtype=np.uint8)
            for disk in instance:
                for x, y, r in disk:
                    cv2.circle(mask, (int(x), int(y)), int(r), (1), -1)

            cont, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                       cv2.CHAIN_APPROX_SIMPLE)
            if len(cont) > 0:
                # keep only the first external contour per instance
                conts.append(cont[0])

        # squeeze OpenCV's (n, 1, 2) contour shape down to (n, 2)
        conts = [cont[:, 0, :] for cont in conts]
        return conts

    def mask2conts(mask):
        """Extract (n, 2) polygon contours from a binary mask."""
        conts, _ = cv2.findContours(mask, cv2.RETR_TREE,
                                    cv2.CHAIN_APPROX_SIMPLE)
        conts = [cont[:, 0, :] for cont in conts]
        return conts

    def rescale_result(image, polygons, H, W):
        """Resize *image* to (W, H) and rescale polygon points to match."""
        ori_H, ori_W = image.shape[:2]
        image = cv2.resize(image, (W, H))
        for polygon in polygons:
            cont = polygon['points']
            cont[:, 0] = (cont[:, 0] * W / ori_W).astype(int)
            cont[:, 1] = (cont[:, 1] * H / ori_H).astype(int)
            polygon['points'] = cont
        return image, polygons

    def rescale_padding_result(image, polygons, ori_h, ori_w):
        """Undo the pad-to-multiple-of-32 resize applied at load time.

        Crops the padded border off *image*, resizes it back to
        (ori_h, ori_w) and rescales polygon coordinates accordingly.
        """
        h, w = image.shape[:2]
        # get no-padding image size
        resize_h = ori_h if ori_h % 32 == 0 else (ori_h // 32) * 32
        resize_w = ori_w if ori_w % 32 == 0 else (ori_w // 32) * 32
        ratio = float(h) / resize_h if resize_h > resize_w else float(
            w) / resize_w
        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        # crop no-padding image (::-1 also flips the channel order)
        no_padding_image = image[0:resize_h, 0:resize_w, ::-1]
        no_padding_image = cv2.resize(no_padding_image, (ori_w, ori_h))

        # rescale points
        for polygon in polygons:
            polygon['points'][:, 0] = (polygon['points'][:, 0] * float(ori_w) /
                                       resize_w).astype(np.int32)
            polygon['points'][:, 1] = (polygon['points'][:, 1] * float(ori_h) /
                                       resize_h).astype(np.int32)

        return no_padding_image, polygons

    def calc_confidence(contours, score_map):
        """Attach a mean-score confidence to every contour.

        Returns a list of ``{'points': cnt, 'confidence': float}`` dicts;
        degenerate (zero-area) contours are dropped.
        """
        polygons = []
        for cnt in contours:
            drawing = np.zeros(score_map.shape[1:], np.int8)
            mask = cv2.fillPoly(drawing, [cnt.astype(np.int32)], 1)
            area = np.sum(np.greater(mask, 0))
            if not area > 0:
                continue

            # NOTE(review): averages channel 0 of score_map, while channel 1
            # is used as the text probability elsewhere -- confirm intended.
            confidence = np.sum(mask * score_map[0]) / area

            polygon = {'points': cnt, 'confidence': confidence}

            polygons.append(polygon)

        return polygons

    def load_model(model, model_path):
        """
        load a pretrained model checkpoint

        Args:
            model: the model instance to load weights into
            model_path: the path to the checkpoint file

        """
        print('Loading from {}'.format(model_path))
        state_dict = torch.load(model_path)
        model.load_state_dict(state_dict['model'])

    def inference(model, detector, test_loader):
        """
        Run the detector over *test_loader* and dump polygons to json.

        """
        # NOTE(review): hard-coded machine-specific dataset root
        gt_json_path = os.path.join('/home/shf/fudan_ocr_system/datasets/',
                                    opt.dataset, 'train_labels.json')
        with open(gt_json_path, 'r') as f:
            gt_dict = json.load(f)

        model.eval()
        result = dict()

        for i, (img, train_mask, tr_mask, tcl_mask, radius_map, sin_map,
                cos_map, meta) in enumerate(test_loader):
            timer = {'model': 0, 'detect': 0, 'viz': 0, 'restore': 0}
            start = time.time()

            img, train_mask, tr_mask, tcl_mask, radius_map, sin_map, cos_map = to_device(
                img, train_mask, tr_mask, tcl_mask, radius_map, sin_map,
                cos_map)
            # inference
            output = model(img)
            if opt.multi_scale:
                # also run at half resolution and upsample back to full size
                size_h, size_w = img.shape[2:4]
                img_rescale = func.interpolate(img,
                                               scale_factor=0.5,
                                               mode='nearest')
                output_rescale = model(img_rescale)
                output_rescale = func.interpolate(output_rescale,
                                                  size=(size_h, size_w),
                                                  mode='nearest')

            timer['model'] = time.time() - start

            for idx in range(img.size(0)):
                start = time.time()
                print('detect {} / {} images: {}.'.format(
                    i, len(test_loader), meta['image_id'][idx]))
                # output channel layout: 0:2 text region, 2:4 text center
                # line, 4 sin, 5 cos, 6 radius
                tr_pred = output[idx, 0:2].softmax(dim=0).data.cpu().numpy()
                tcl_pred = output[idx, 2:4].softmax(dim=0).data.cpu().numpy()
                sin_pred = output[idx, 4].data.cpu().numpy()
                cos_pred = output[idx, 5].data.cpu().numpy()
                radii_pred = output[idx, 6].data.cpu().numpy()

                # saturate scores above tr_conf_thresh to 1 (with the default
                # tr_conf_thresh=1.0 the soft scores pass through unchanged)
                tr_pred_mask = np.where(tr_pred[1] > detector.tr_conf_thresh,
                                        1, tr_pred[1])

                tcl_pred_mask = (tcl_pred *
                                 tr_pred_mask)[1] > detector.tcl_conf_thresh

                if opt.multi_scale:
                    tr_pred_rescale = output_rescale[
                        idx, 0:2].sigmoid().data.cpu().numpy()
                    tcl_pred_rescale = output_rescale[idx, 2:4].softmax(
                        dim=0).data.cpu().numpy()

                    # merge the two scales: clipped sum for the tr mask
                    tr_pred_scale_mask = np.where(
                        tr_pred_rescale[1] + tr_pred[1] > 1, 1,
                        tr_pred_rescale[1] + tr_pred[1])
                    tr_pred_mask = tr_pred_scale_mask

                    # weighted adding
                    origin_ratio = 0.5
                    rescale_ratio = 0.5
                    tcl_pred = (tcl_pred * origin_ratio +
                                tcl_pred_rescale * rescale_ratio).astype(
                                    np.float32)
                    tcl_pred_mask = (
                        tcl_pred * tr_pred_mask)[1] > detector.tcl_conf_thresh

                batch_result = detector.complete_detect(
                    tr_pred_mask, tcl_pred_mask, sin_pred, cos_pred,
                    radii_pred)  # (n_tcl, 3)
                timer['detect'] = time.time() - start

                start = time.time()
                # visualization: de-normalize the input back to uint8 BGR
                img_show = img[idx].permute(1, 2, 0).cpu().numpy()
                img_show = ((img_show * opt.stds + opt.means) * 255).astype(
                    np.uint8)
                H, W = meta['Height'][idx].item(), meta['Width'][idx].item()

                # get pred_contours
                contours = result2polygon(img_show, batch_result)

                if opt.viz:
                    resize_H = H if H % 32 == 0 else (H // 32) * 32
                    resize_W = W if W % 32 == 0 else (W // 32) * 32

                    ratio = float(
                        img_show.shape[0]
                    ) / resize_H if resize_H > resize_W else float(
                        img_show.shape[1]) / resize_W
                    resize_H = int(resize_H * ratio)
                    resize_W = int(resize_W * ratio)

                    # image_id looks like "gt_xxx_123.jpg"; index ground
                    # truth by the numeric suffix
                    gt_info = gt_dict[int(meta['image_id'][idx].lstrip(
                        'gt_').rstrip('.jpg').split('_')[1])]

                    gt_contours = []
                    gt_cont = np.array(gt_info['points'])
                    gt_cont[:, 0] = gt_cont[:, 0] * float(resize_W) / float(W)
                    gt_cont[:, 1] = gt_cont[:, 1] * float(resize_H) / float(H)
                    gt_contours.append(gt_cont.astype(np.int32))
                    illegal_contours = mask2conts(
                        meta['illegal_mask'][idx].cpu().numpy())

                    # stack prediction and ground truth vertically
                    predict_vis = visualize_detection(
                        img_show, tr_pred_mask, tcl_pred_mask.astype(np.uint8),
                        contours.copy())
                    gt_vis = visualize_detection(img_show,
                                                 tr_mask[idx].cpu().numpy(),
                                                 tcl_mask[idx].cpu().numpy(),
                                                 gt_contours, illegal_contours)
                    im_vis = np.concatenate([predict_vis, gt_vis], axis=0)
                    path = os.path.join(opt.vis_dir, meta['image_id'][idx])
                    cv2.imwrite(path, im_vis)
                timer['viz'] = time.time() - start

                start = time.time()
                polygons = calc_confidence(contours, tr_pred)
                img_show, polygons = rescale_padding_result(
                    img_show, polygons, H, W)

                # drop too-small polygons. Fix: the original looped with
                # `for i, poly in enumerate(polygons)`, clobbering the outer
                # enumerate(test_loader) counter `i` used by the progress
                # print above.
                polygons = [
                    poly for poly in polygons
                    if cv2.contourArea(poly['points']) >= 100
                ]

                # convert np.array to list so the result is json-serializable
                for polygon in polygons:
                    polygon['points'] = polygon['points'].tolist()

                result[meta['image_id'][idx].replace('.jpg', '').replace(
                    'gt', 'res')] = polygons
                timer['restore'] = time.time() - start

            print(
                'Cost time {:.2f}s: model {:.2f}s, detect {:.2f}s, viz {:.2f}s, restore {:.2f}s'
                .format(
                    timer['model'] + timer['detect'] + timer['viz'] +
                    timer['restore'], timer['model'], timer['detect'],
                    timer['viz'], timer['restore']))

        # write to json file
        with open(os.path.join(opt.output_dir, 'result.json'), 'w') as f:
            json.dump(result, f)
            print("Output json file in {}.".format(opt.output_dir))

    torch.cuda.set_device(opt.num_device)
    option = BaseOptions(config_file)
    args = option.initialize()

    update_config(opt, args)
    print_config(opt)
    data_root = os.path.join(opt.data_root, opt.dataset)
    testset = TotalText(
        data_root=data_root,
        ignore_list=os.path.join(data_root, 'ignore_list.txt'),
        is_training=False,
        transform=EvalTransform(size=1280, mean=opt.means, std=opt.stds)
    )
    test_loader = data.DataLoader(testset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=opt.num_workers)

    # Model
    model = TextNet(backbone=opt.backbone, output_channel=7)
    model_path = os.path.join(
        opt.save_dir, opt.exp_name,
        'textsnake_{}_{}.pth'.format(model.backbone_name, opt.checkepoch))
    load_model(model, model_path)

    # copy to cuda
    model = model.to(opt.device)
    if opt.cuda:
        cudnn.benchmark = True
    # tr_conf_thresh=1.0 disables tr-score saturation in inference()
    detector = TextDetector(tcl_conf_thresh=0.3, tr_conf_thresh=1.0)  # 0.3

    # check vis_dir and output_dir exist
    if opt.viz:
        if not os.path.exists(opt.vis_dir):
            os.mkdir(opt.vis_dir)
    if not os.path.exists(opt.output_dir):
        os.mkdir(opt.output_dir)

    print('Start testing TextSnake.')

    inference(model, detector, test_loader)
    detval()

    print('End.')
Example #13
0
def main():
    """Train TextSnake on the dataset selected by ``cfg.dataset``.

    Builds train/validation loaders, constructs the model (optionally
    multi-GPU / resumed from a checkpoint), then runs the epoch loop,
    validating after each epoch when a validation set exists.
    """

    global lr

    if cfg.dataset == 'total-text':

        trainset = TotalText(data_root='data/total-text',
                             ignore_list=None,
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))

        valset = TotalText(data_root='data/total-text',
                           ignore_list=None,
                           is_training=False,
                           transform=BaseTransform(size=cfg.input_size,
                                                   mean=cfg.means,
                                                   std=cfg.stds))

    elif cfg.dataset == 'synth-text':
        trainset = SynthText(data_root='data/SynthText',
                             is_training=True,
                             transform=Augmentation(size=cfg.input_size,
                                                    mean=cfg.means,
                                                    std=cfg.stds))
        valset = None
    else:
        # fail fast instead of crashing later with a NameError on trainset
        raise ValueError('unsupported dataset: {}'.format(cfg.dataset))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers)
    # Passing pin_memory=True to DataLoader would allocate the produced
    # tensors in page-locked memory, speeding up host-to-GPU transfers.
    if valset:
        val_loader = data.DataLoader(valset,
                                     batch_size=cfg.batch_size,
                                     shuffle=False,
                                     num_workers=cfg.num_workers)
    else:
        # fix: the original re-assigned valset = None here (a no-op),
        # leaving val_loader undefined on this branch
        val_loader = None

    # timestamped run directory, e.g. <log_dir>/Jan01_12-00-00_<exp_name>
    log_dir = os.path.join(
        cfg.log_dir,
        datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
    logger = LogSummary(log_dir)

    # Model (backbone defaults to vgg); wrap for multi-GPU when requested
    model = TextNet(is_training=True, backbone=cfg.net)
    if cfg.mgpu:
        model = nn.DataParallel(model)

    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True
    # resume training from a checkpoint if one is given
    if cfg.resume:
        load_model(model, cfg.resume)

    # TextLoss combines loss_tr, loss_tcl, loss_radii, loss_sin and loss_cos
    criterion = TextLoss()
    lr = cfg.lr
    # Adam optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

    if cfg.dataset == 'synth-text':
        # fixed learning rate (initial lr kept when last_epoch == -1)
        scheduler = FixLR(optimizer)
    else:
        # decay lr by 0.1 every step_size steps
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    print('Start training TextSnake.')
    for epoch in range(cfg.start_epoch, cfg.max_epoch):
        train(model, train_loader, criterion, scheduler, optimizer, epoch,
              logger)
        if valset:
            validation(model, val_loader, criterion, epoch, logger)

    print('End.')
Example #14
0
                               num_workers=cfg.num_workers)

# build a validation loader only when a validation set exists
if valset:
    val_loader = data.DataLoader(valset,
                                 batch_size=cfg.batch_size,
                                 shuffle=False,
                                 num_workers=cfg.num_workers)
else:
    # NOTE(review): this re-assigns valset (already falsy) instead of
    # setting val_loader; val_loader stays undefined on this branch
    valset = None

# timestamped log directory, e.g. <log_dir>/Jan01_12-00-00_<exp_name>
log_dir = os.path.join(
    cfg.log_dir,
    datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
logger = LogSummary(log_dir)

# TextSnake network; wrap in DataParallel when multi-GPU is requested
model = TextNet(is_training=True, backbone=cfg.net)
if cfg.mgpu:
    model = nn.DataParallel(model)

model = model.to(cfg.device)

if cfg.cuda:
    cudnn.benchmark = True

# resume training from a checkpoint if one is given
if cfg.resume:
    load_model(model, cfg.resume)

criterion = TextLoss()
lr = cfg.lr
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
Example #15
0
class text_detection(object):
    def __init__(self):
        """Set up the TextSnake ROS node: model, publishers, services,
        and synchronized color/depth subscribers."""
        # inference is gated by a switch service; start disabled
        self.switch = False
        r = rospkg.RosPack()
        self.path = r.get_path('textsnake')
        self.commodity_list = []
        self.read_commodity(
            r.get_path('text_msgs') + "/config/commodity_list.txt")
        self.prob_threshold = 0.90
        self.cv_bridge = CvBridge()
        # ImageNet normalization statistics used by the network
        self.means = (0.485, 0.456, 0.406)
        self.stds = (0.229, 0.224, 0.225)

        # when True, processed frames are dumped under <path>/saver/*
        self.saver = False

        self.color_map = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
                          (255, 255, 255)]  # 0 90 180 270 noise

        self.objects = []
        self.network = TextNet(is_training=False, backbone='vgg')
        self.is_compressed = False

        self.cuda_use = torch.cuda.is_available()

        if self.cuda_use:
            self.network = self.network.cuda()

        # load pretrained weights shipped with the package
        model_name = "textsnake.pth"
        self.network.load_model(os.path.join(self.path, "weights/",
                                             model_name))

        self.detector = TextDetector(self.network,
                                     tr_thresh=0.6,
                                     tcl_thresh=0.4)
        self.network.eval()
        #### Publisher
        self.image_pub = rospy.Publisher("~predict_img", Image, queue_size=1)
        self.img_bbox_pub = rospy.Publisher("~predict_bbox",
                                            Image,
                                            queue_size=1)
        self.predict_img_pub = rospy.Publisher("/prediction_img",
                                               Image,
                                               queue_size=1)
        self.predict_mask_pub = rospy.Publisher("/prediction_mask",
                                                Image,
                                                queue_size=1)
        self.text_detection_pub = rospy.Publisher("/text_detection_array",
                                                  text_detection_array,
                                                  queue_size=1)
        ### service
        self.predict_switch_ser = rospy.Service("~predict_switch_server",
                                                predict_switch,
                                                self.switch_callback)
        self.predict_ser = rospy.Service("~text_detection", text_detection_srv,
                                         self.srv_callback)
        ### msg filter: synchronize color and aligned-depth streams
        image_sub = message_filters.Subscriber('/camera/color/image_raw',
                                               Image)
        depth_sub = message_filters.Subscriber(
            '/camera/aligned_depth_to_color/image_raw', Image)
        ts = message_filters.TimeSynchronizer([image_sub, depth_sub], 10)
        ts.registerCallback(self.callback)
        self.saver_count = 0
        # create the dump directories up front when saving is enabled
        if self.saver:
            self.p_img = os.path.join(self.path, "saver", "img")
            if not os.path.exists(self.p_img):
                os.makedirs(self.p_img)
            self.p_depth = os.path.join(self.path, "saver", "depth")
            if not os.path.exists(self.p_depth):
                os.makedirs(self.p_depth)
            self.p_mask = os.path.join(self.path, "saver", "mask")
            if not os.path.exists(self.p_mask):
                os.makedirs(self.p_mask)
            self.p_result = os.path.join(self.path, "saver", "result")
            if not os.path.exists(self.p_result):
                os.makedirs(self.p_result)

        print "============ Ready ============"
        print "TextSnake Model Parameters number: " + str(
            self.count_parameters(self.network))

    def read_commodity(self, path):
        """Append every commodity name (one per line) from *path*
        to ``self.commodity_list``."""
        # use a context manager so the file handle is closed deterministically
        # (the original iterated a bare open() and leaked the handle)
        with open(path, "r") as f:
            for line in f:
                self.commodity_list.append(line.rstrip('\n'))
        print("Node (text_detection): Finish reading list")

    def count_parameters(self, model):
        """Return the total number of trainable parameters in *model*."""
        total = 0
        for param in model.parameters():
            if param.requires_grad:
                total += param.numel()
        return total

    def srv_callback(self, req):
        """Service handler: grab one color/depth frame, run detection at
        four rotations (0/90/180/270 deg), forward each detection set to
        the recognition service, merge the rotated-back masks, publish
        visualizations and return them in the response."""
        text_array = text_detection_array()

        resp = text_detection_srvResponse()
        # block until one synchronized color + depth frame arrives
        img_msg = rospy.wait_for_message('/camera/color/image_raw',
                                         Image,
                                         timeout=None)
        resp.depth = rospy.wait_for_message(
            '/camera/aligned_depth_to_color/image_raw', Image, timeout=None)
        resp.image = img_msg
        try:
            if self.is_compressed:
                np_arr = np.fromstring(img_msg.data, np.uint8)
                cv_image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
            else:
                cv_image = self.cv_bridge.imgmsg_to_cv2(img_msg, "bgr8")
        except CvBridgeError as e:
            resp.status = e
            print(e)
        # zero-pad the frame so both sides are multiples of 32
        # (the network's required input granularity)
        (rows, cols, channels) = cv_image.shape
        rows = int(np.ceil(rows / 32.) * 32)
        cols = int(np.ceil(cols / 32.) * 32)
        cv_image1 = np.zeros((rows, cols, channels), dtype=np.uint8)
        cv_image1[:cv_image.shape[0], :cv_image.shape[1], :cv_image.
                  shape[2]] = cv_image[:, :, :]
        cv_image = cv_image1.copy()

        # accumulated recognition mask over all four rotations
        mask = np.zeros([cv_image.shape[0], cv_image.shape[1]], dtype=np.uint8)
        img_list_0_90_180_270 = rotate_cv(cv_image)

        for i in range(4):

            # detect text in the i-th rotated copy
            predict_img, contours = self.predict(img_list_0_90_180_270[i])
            img_bbox = img_list_0_90_180_270[i].copy()

            text_array = text_detection_array()
            text_array.image = self.cv_bridge.cv2_to_imgmsg(
                img_list_0_90_180_270[i], "bgr8")
            text_array.depth = resp.depth
            # pack every contour as a message plus its axis-aligned bbox
            for _cont in contours:
                text_bb = text_detection_msg()
                for p in _cont:
                    int_array = int_arr()
                    int_array.point.append(p[0])
                    int_array.point.append(p[1])
                    text_bb.contour.append(int_array)
                cv2.drawContours(predict_img, [_cont], -1, self.color_map[i],
                                 3)
                text_bb.box.xmin = min(_cont[:, 0])
                text_bb.box.xmax = max(_cont[:, 0])
                text_bb.box.ymin = min(_cont[:, 1])
                text_bb.box.ymax = max(_cont[:, 1])
                text_array.text_array.append(text_bb)
                cv2.rectangle(img_bbox, (text_bb.box.xmin, text_bb.box.ymin),
                              (text_bb.box.xmax, text_bb.box.ymax),
                              self.color_map[i], 3)
            text_array.bb_count = len(text_array.text_array)
            # self.text_detection_pub.publish(text_array)

            # hand the detections to the recognition service for this
            # rotation; failures are stored in the response state
            recog_req = text_recognize_srvRequest()
            recog_resp = text_recognize_srvResponse()
            try:
                rospy.wait_for_service(RECOG_SRV, timeout=10)
                recog_req.data = text_array
                recog_req.direct = i
                recognition_srv = rospy.ServiceProxy(RECOG_SRV,
                                                     text_recognize_srv)
                recog_resp = recognition_srv(recog_req)
            except (rospy.ServiceException, rospy.ROSException), e:
                resp.state = e

            recog_mask = self.cv_bridge.imgmsg_to_cv2(recog_resp.mask, "8UC1")

            # rotate results back into the original frame orientation
            if i == 0:
                pass
            elif i == 1:
                recog_mask = rotate_back_change_h_w(recog_mask, angle=-90)
                predict_img = rotate_back_change_h_w(predict_img, angle=-90)
                img_bbox = rotate_back_change_h_w(img_bbox, angle=-90)
            elif i == 2:
                recog_mask = rotate_back(recog_mask, angle=-180)
                predict_img = rotate_back(predict_img, angle=-180)
                img_bbox = rotate_back(img_bbox, angle=-180)
            else:
                recog_mask = rotate_back_change_h_w(recog_mask, angle=-270)
                predict_img = rotate_back_change_h_w(predict_img, angle=-270)
                img_bbox = rotate_back_change_h_w(img_bbox, angle=-270)

            # merge non-zero recognition labels into the combined mask
            mask[recog_mask != 0] = recog_mask[recog_mask != 0]

            try:
                self.image_pub.publish(
                    self.cv_bridge.cv2_to_imgmsg(predict_img, "bgr8"))
                self.img_bbox_pub.publish(
                    self.cv_bridge.cv2_to_imgmsg(img_bbox, "bgr8"))
            except CvBridgeError as e:
                resp.state = e
                print(e)

        ## publish visualization
        self.img_show(mask, cv_image)
        resp.mask = self.cv_bridge.cv2_to_imgmsg(mask, "8UC1")
        # inverted-label view of the merged mask for display
        vis_mask = np.zeros([cv_image.shape[0], cv_image.shape[1]],
                            dtype=np.uint8)
        vis_mask[mask != 0] = 255 - mask[mask != 0]
        if self.saver:
            self.save_func(cv_image1, vis_mask,
                           self.cv_bridge.imgmsg_to_cv2(resp.depth, "16UC1"),
                           cv_image)
        ## srv end
        self.predict_img_pub.publish(
            self.cv_bridge.cv2_to_imgmsg(cv_image, "bgr8"))
        self.predict_mask_pub.publish(
            self.cv_bridge.cv2_to_imgmsg(vis_mask, "8UC1"))
        return resp
Example #16
0
    def __init__(self):
        """Initialize the node: load the TextSnake network and weights,
        then wire up ROS publishers, services and message filters."""
        # prediction is toggled through the switch service; off by default
        self.switch = False
        r = rospkg.RosPack()
        self.path = r.get_path('textsnake')
        self.commodity_list = []
        self.read_commodity(
            r.get_path('text_msgs') + "/config/commodity_list.txt")
        self.prob_threshold = 0.90
        self.cv_bridge = CvBridge()
        # ImageNet normalization statistics expected by the backbone
        self.means = (0.485, 0.456, 0.406)
        self.stds = (0.229, 0.224, 0.225)

        # set True to dump every processed frame under <path>/saver/*
        self.saver = False

        self.color_map = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
                          (255, 255, 255)]  # 0 90 180 270 noise

        self.objects = []
        self.network = TextNet(is_training=False, backbone='vgg')
        self.is_compressed = False

        self.cuda_use = torch.cuda.is_available()

        if self.cuda_use:
            self.network = self.network.cuda()

        # pretrained checkpoint shipped inside the package
        model_name = "textsnake.pth"
        self.network.load_model(os.path.join(self.path, "weights/",
                                             model_name))

        self.detector = TextDetector(self.network,
                                     tr_thresh=0.6,
                                     tcl_thresh=0.4)
        self.network.eval()
        #### Publisher
        self.image_pub = rospy.Publisher("~predict_img", Image, queue_size=1)
        self.img_bbox_pub = rospy.Publisher("~predict_bbox",
                                            Image,
                                            queue_size=1)
        self.predict_img_pub = rospy.Publisher("/prediction_img",
                                               Image,
                                               queue_size=1)
        self.predict_mask_pub = rospy.Publisher("/prediction_mask",
                                                Image,
                                                queue_size=1)
        self.text_detection_pub = rospy.Publisher("/text_detection_array",
                                                  text_detection_array,
                                                  queue_size=1)
        ### service
        self.predict_switch_ser = rospy.Service("~predict_switch_server",
                                                predict_switch,
                                                self.switch_callback)
        self.predict_ser = rospy.Service("~text_detection", text_detection_srv,
                                         self.srv_callback)
        ### msg filter: keep color and aligned-depth frames in lockstep
        image_sub = message_filters.Subscriber('/camera/color/image_raw',
                                               Image)
        depth_sub = message_filters.Subscriber(
            '/camera/aligned_depth_to_color/image_raw', Image)
        ts = message_filters.TimeSynchronizer([image_sub, depth_sub], 10)
        ts.registerCallback(self.callback)
        self.saver_count = 0
        # pre-create dump directories when frame saving is enabled
        if self.saver:
            self.p_img = os.path.join(self.path, "saver", "img")
            if not os.path.exists(self.p_img):
                os.makedirs(self.p_img)
            self.p_depth = os.path.join(self.path, "saver", "depth")
            if not os.path.exists(self.p_depth):
                os.makedirs(self.p_depth)
            self.p_mask = os.path.join(self.path, "saver", "mask")
            if not os.path.exists(self.p_mask):
                os.makedirs(self.p_mask)
            self.p_result = os.path.join(self.path, "saver", "result")
            if not os.path.exists(self.p_result):
                os.makedirs(self.p_result)

        print "============ Ready ============"
        print "TextSnake Model Parameters number: " + str(
            self.count_parameters(self.network))
Example #17
0
def main():
    """Train TextSnake end-to-end on the dataset selected by ``cfg.dataset``.

    Builds the training/validation datasets and loaders, constructs the
    TextNet model (optionally multi-GPU / resumed from a checkpoint), and
    runs the train/validation loop for ``cfg.max_epoch`` epochs, logging
    losses to both ``LogSummary`` and TensorBoard.

    All settings come from the module-level ``cfg`` namespace; takes no
    arguments and returns nothing.
    """
    trainset = CustomText(data_root='data/{}'.format(cfg.dataset),
                          is_training=True,
                          transform=Augmentation(size=cfg.input_size,
                                                 mean=cfg.means,
                                                 std=cfg.stds))

    valset = CustomText(data_root='data/{}'.format(cfg.dataset),
                        is_training=False,
                        transform=BaseTransform(size=cfg.input_size,
                                                mean=cfg.means,
                                                std=cfg.stds))

    train_loader = data.DataLoader(trainset,
                                   batch_size=cfg.batch_size,
                                   shuffle=True,
                                   num_workers=cfg.num_workers,
                                   collate_fn=my_collate)
    # Validation is optional: only build its loader when the split is
    # non-empty (a dataset with __len__() == 0 is falsy).
    if valset:
        val_loader = data.DataLoader(valset,
                                     batch_size=cfg.batch_size,
                                     shuffle=False,
                                     num_workers=cfg.num_workers,
                                     collate_fn=my_collate)
    else:
        valset = None

    # Timestamped log directory so repeated runs never collide.
    log_dir = os.path.join(
        cfg.log_dir,
        datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
    logger = LogSummary(log_dir)

    # Model
    model = TextNet(is_training=True, backbone=cfg.net)
    if cfg.mgpu:
        model = nn.DataParallel(model)

    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    if cfg.resume:
        load_model(model, cfg.resume)

    criterion = TextLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

    # synth-text pre-training keeps a fixed LR; every other dataset decays
    # the LR by 10x every 100 epochs.
    if cfg.dataset == 'synth-text':
        scheduler = FixLR(optimizer)
    else:
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)

    print('Start training TextSnake.')
    writer = SummaryWriter(logdir='process')
    try:
        for epoch in range(cfg.start_epoch, cfg.max_epoch):

            train(model, train_loader, criterion, scheduler, optimizer, epoch,
                  logger)
            if valset:
                validation_loss = validation(model, val_loader, criterion,
                                             epoch, logger)

                writer.add_scalar('Validation Loss', validation_loss, epoch)
    finally:
        # Flush pending TensorBoard events even if training is interrupted;
        # the original leaked the writer and could drop the final scalars.
        writer.close()

    print('End.')
Example #18
0
    def Train(self):
        cfg.max_epoch = self.system_dict["params"]["max_epoch"]
        cfg.means = self.system_dict["params"]["means"]
        cfg.stds = self.system_dict["params"]["stds"]
        cfg.log_dir = self.system_dict["params"]["log_dir"]
        cfg.exp_name = self.system_dict["params"]["exp_name"]
        cfg.net = self.system_dict["params"]["net"]
        cfg.resume = self.system_dict["params"]["resume"]
        cfg.cuda = self.system_dict["params"]["cuda"]
        cfg.mgpu = self.system_dict["params"]["mgpu"]
        cfg.save_dir = self.system_dict["params"]["save_dir"]
        cfg.vis_dir = self.system_dict["params"]["vis_dir"]
        cfg.input_channel = self.system_dict["params"]["input_channel"]

        cfg.lr = self.system_dict["params"]["lr"]
        cfg.weight_decay = self.system_dict["params"]["weight_decay"]
        cfg.gamma = self.system_dict["params"]["gamma"]
        cfg.momentum = self.system_dict["params"]["momentum"]
        cfg.optim = self.system_dict["params"]["optim"]
        cfg.display_freq = self.system_dict["params"]["display_freq"]
        cfg.viz_freq = self.system_dict["params"]["viz_freq"]
        cfg.save_freq = self.system_dict["params"]["save_freq"]
        cfg.log_freq = self.system_dict["params"]["log_freq"]

        cfg.batch_size = self.system_dict["params"]["batch_size"]
        cfg.rescale = self.system_dict["params"]["rescale"]
        cfg.checkepoch = self.system_dict["params"]["checkepoch"]

        cfg.val_freq = 1000
        cfg.start_iter = 0
        cfg.loss = "CrossEntropyLoss"
        cfg.pretrain = False
        cfg.verbose = True
        cfg.viz = True
        cfg.lr_adjust = "step"  #fix, step
        cfg.stepvalues = []
        cfg.step_size = cfg.max_epoch // 2

        if cfg.cuda and torch.cuda.is_available():
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
            cudnn.benchmark = True
            cfg.device = torch.device("cuda")
        else:
            torch.set_default_tensor_type('torch.FloatTensor')
            cfg.device = torch.device("cpu")

        # Create weights saving directory
        if not os.path.exists(cfg.save_dir):
            os.mkdir(cfg.save_dir)

        # Create weights saving directory of target model
        model_save_path = os.path.join(cfg.save_dir, cfg.exp_name)

        if not os.path.exists(model_save_path):
            os.mkdir(model_save_path)

        if not os.path.exists(cfg.vis_dir):
            os.mkdir(cfg.vis_dir)

        if (self.system_dict["params"]["annotation_type"] == "text"):
            trainset = TotalText_txt(
                self.system_dict["params"]["train_img_folder"],
                self.system_dict["params"]["train_anno_folder"],
                ignore_list=None,
                is_training=True,
                transform=Augmentation(size=cfg.input_size,
                                       mean=cfg.means,
                                       std=cfg.stds))
            train_loader = data.DataLoader(trainset,
                                           batch_size=cfg.batch_size,
                                           shuffle=True,
                                           num_workers=cfg.num_workers)

            if (self.system_dict["params"]["val_dataset"]):
                valset = TotalText_txt(
                    self.system_dict["params"]["val_img_folder"],
                    self.system_dict["params"]["val_anno_folder"],
                    ignore_list=None,
                    is_training=False,
                    transform=BaseTransform(size=cfg.input_size,
                                            mean=cfg.means,
                                            std=cfg.stds))
                val_loader = data.DataLoader(valset,
                                             batch_size=cfg.batch_size,
                                             shuffle=False,
                                             num_workers=cfg.num_workers)
            else:
                valset = None

        elif (self.system_dict["params"]["annotation_type"] == "mat"):
            trainset = TotalText_mat(
                self.system_dict["params"]["train_img_folder"],
                self.system_dict["params"]["train_anno_folder"],
                ignore_list=None,
                is_training=True,
                transform=Augmentation(size=cfg.input_size,
                                       mean=cfg.means,
                                       std=cfg.stds))
            train_loader = data.DataLoader(trainset,
                                           batch_size=cfg.batch_size,
                                           shuffle=True,
                                           num_workers=cfg.num_workers)

            if (self.system_dict["params"]["val_dataset"]):
                valset = TotalText_mat(
                    self.system_dict["params"]["val_img_folder"],
                    self.system_dict["params"]["val_anno_folder"],
                    ignore_list=None,
                    is_training=False,
                    transform=BaseTransform(size=cfg.input_size,
                                            mean=cfg.means,
                                            std=cfg.stds))
                val_loader = data.DataLoader(valset,
                                             batch_size=cfg.batch_size,
                                             shuffle=False,
                                             num_workers=cfg.num_workers)
            else:
                valset = None

        log_dir = os.path.join(
            cfg.log_dir,
            datetime.now().strftime('%b%d_%H-%M-%S_') + cfg.exp_name)
        logger = LogSummary(log_dir)

        model = TextNet(is_training=True, backbone=cfg.net)
        if cfg.mgpu:
            model = nn.DataParallel(model)

        model = model.to(cfg.device)

        if cfg.cuda:
            cudnn.benchmark = True

        if cfg.resume:
            self.load_model(model, cfg.resume)

        criterion = TextLoss()
        lr = cfg.lr
        if (cfg.optim == "adam"):
            optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
        else:
            optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)

        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=cfg.step_size,
                                        gamma=cfg.gamma)

        train_step = 0
        print('Start training TextSnake.')
        for epoch in range(cfg.start_epoch, cfg.max_epoch):
            train_step = self.train(model, train_loader, criterion, scheduler,
                                    optimizer, epoch, logger, train_step)
            if valset:
                self.validation(model, val_loader, criterion, epoch, logger)
            self.save_model(model, "final", scheduler.get_lr(), optimizer)