Example #1
def main(args):
    if args.checkpoint == '':
        args.checkpoint = "checkpoints/ic15_%s_bs_%d_ep_%d" % (
            args.arch, args.batch_size, args.n_epoch)
    if args.pretrain:
        if 'synth' in args.pretrain:
            args.checkpoint += "_pretrain_synth"
        else:
            args.checkpoint += "_pretrain_ic17"

    print('checkpoint path: %s' % args.checkpoint)
    print('init lr: %.8f' % args.lr)
    print('schedule: ', args.schedule)
    sys.stdout.flush()

    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    kernel_num = 1
    min_scale = 0.4
    start_epoch = 0

    data_loader = IC15Loader(root_dir=args.root_dir,
                             is_transform=True,
                             img_size=args.img_size,
                             kernel_num=kernel_num,
                             min_scale=min_scale)
    train_loader = torch.utils.data.DataLoader(data_loader,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=3,
                                               drop_last=True,
                                               pin_memory=True)

    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet18":
        model = models.resnet18(pretrained=True, num_classes=kernel_num)
    model = torch.nn.DataParallel(model).cuda()

    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    title = 'icdar2015'
    if args.pretrain:
        print('Using pretrained model.')
        assert os.path.isfile(
            args.pretrain), 'Error: no pretrained model file found!'
        checkpoint = torch.load(args.pretrain)
        model.load_state_dict(checkpoint['state_dict'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])
    elif args.resume:
        print('Resuming from checkpoint.')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])

    writer = SummaryWriter(logdir=args.checkpoint)
    for epoch in range(start_epoch, args.n_epoch):
        adjust_learning_rate(args, optimizer, epoch)
        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.n_epoch, optimizer.param_groups[0]['lr']))

        train_loss, train_te_acc, train_te_iou = train(train_loader, model,
                                                       dice_loss, optimizer,
                                                       epoch, writer)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'lr': args.lr,
                'optimizer': optimizer.state_dict(),
            },
            checkpoint=args.checkpoint)
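
For context: main(args) in Example #1 reads a number of attributes from args (arch, img_size, root_dir, batch_size, n_epoch, lr, schedule, checkpoint, pretrain, resume). A minimal argparse sketch that would satisfy this interface is shown below; the flag names follow those attributes, but the defaults are assumptions, not values from the original training script.

# Hypothetical argparse setup matching the attributes read by main(args) above.
# Defaults are illustrative assumptions, not the original script's values.
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description='PSENet IC15 training (sketch)')
    parser.add_argument('--arch', default='resnet50', type=str)
    parser.add_argument('--img_size', default=640, type=int)
    parser.add_argument('--root_dir', default='./data/', type=str)
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--n_epoch', default=600, type=int)
    parser.add_argument('--lr', default=1e-3, type=float)
    parser.add_argument('--schedule', default=[200, 400], type=int, nargs='+',
                        help='epochs at which the learning rate is decayed')
    parser.add_argument('--checkpoint', default='', type=str)
    parser.add_argument('--pretrain', default='', type=str)
    parser.add_argument('--resume', default='', type=str)
    return parser.parse_args()


if __name__ == '__main__':
    main(parse_args())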
Example #2
def main(args):
    if args.checkpoint == '':
        # args.checkpoint = "checkpointsfuns/funs19_%s_bs_%d_ep_%d"%(args.arch, args.batch_size, args.n_epoch)
        args.checkpoint = "checkpoints/model_funs_pretrain_ic15_frozen_dense_layers"
    if args.pretrain:
        if 'synth' in args.pretrain:
            args.checkpoint += "_pretrain_synth"
        else:
            args.checkpoint += "_pretrain_ic17"

    print('checkpoint path: %s' % args.checkpoint)
    print('init lr: %.8f' % args.lr)
    print('schedule: ', args.schedule)
    sys.stdout.flush()

    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    kernel_num = 7
    min_scale = 0.4
    start_epoch = 0

    data_loader = IC15Loader(is_transform=True,
                             img_size=args.img_size,
                             kernel_num=kernel_num,
                             min_scale=min_scale)
    train_loader = torch.utils.data.DataLoader(data_loader,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=3,
                                               drop_last=True,
                                               pin_memory=True)

    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)
    elif args.arch == "pvanet":
        model = models.pvanet(inputsize=args.img_size, num_classes=kernel_num)

    model = torch.nn.DataParallel(model).cuda()

    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    title = 'icdar2015'
    if args.pretrain:
        print('Using pretrained model.')
        assert os.path.isfile(
            args.pretrain), 'Error: no pretrained model file found!'
        checkpoint = torch.load(args.pretrain)
        model.load_state_dict(checkpoint['state_dict'])
        # fine tune output layers
        # grad = [
        #     'module.conv2.weight'
        #     'module.conv2.bias',
        #     'module.bn2.weight',
        #     'module.bn2.bias',
        #     'module.conv3.weight',
        #     'module.conv3.bias'
        # ]
        # for name,value in model.named_parameters():
        #     if name in grad:
        #         value.requires_grad = True
        #     else:
        #         value.requires_grad = False
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])
    elif args.resume:
        print('Resuming from checkpoint.')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])

    writer = SummaryWriter(args.summary_path)

    for epoch in range(start_epoch, args.n_epoch):
        adjust_learning_rate(args, optimizer, epoch)
        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.n_epoch, optimizer.param_groups[0]['lr']))

        train_loss, train_te_acc, train_ke_acc, train_te_iou, train_ke_iou = train(
            train_loader, model, dice_loss, optimizer, epoch)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'lr': args.lr,
                'optimizer': optimizer.state_dict(),
            },
            checkpoint=args.checkpoint)

        writer.add_scalar('Loss', train_loss, epoch)
        writer.add_scalar('train_te_acc', train_te_acc, epoch)
        writer.add_scalar('train_te_iou', train_te_iou, epoch)
        writer.flush()

        logger.append([
            optimizer.param_groups[0]['lr'], train_loss, train_te_acc,
            train_te_iou
        ])
    logger.close()
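
All four examples drive util.Logger through the same set_names / append / close sequence. Below is a minimal sketch of a logger exposing that interface, assuming a tab-separated text file; the real util.Logger in the PSENet repository may format columns differently and handle resume=True in its own way.

# Minimal logger sketch with the set_names / append / close interface used above.
# This is an assumption about util.Logger's behaviour, not its actual implementation.
class SimpleLogger(object):
    def __init__(self, fpath, title=None, resume=False):
        self.title = title
        self.file = open(fpath, 'a' if resume else 'w')

    def set_names(self, names):
        # Header row, written once per training run.
        self.file.write('\t'.join(names) + '\n')
        self.file.flush()

    def append(self, values):
        # One row per epoch, in the same column order as set_names.
        self.file.write('\t'.join('%.6f' % v for v in values) + '\n')
        self.file.flush()

    def close(self):
        self.file.close()


# Usage mirroring the training loops above.
logger = SimpleLogger('log.txt', title='icdar2015')
logger.set_names(['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])
logger.append([1e-3, 0.85, 0.91, 0.74])
logger.close()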
Example #3
def main(args):
    if args.checkpoint == '':
        args.checkpoint = "checkpoints/ic15_%s_bs_%d_ep_%d"%(args.arch, args.batch_size, args.n_epoch)
    if args.pretrain:
        if 'synth' in args.pretrain:
            args.checkpoint += "_pretrain_synth"
        else:
            args.checkpoint += "_pretrain_s1280"

    print('checkpoint path: %s' % args.checkpoint)
    print('init lr: %.8f' % args.lr)
    print('schedule: ', args.schedule)
    sys.stdout.flush()

    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    writer = SummaryWriter(args.checkpoint)

    kernel_num = 18
    start_epoch = 0
    data_loader = IC15Loader(is_transform=True, img_size=args.img_size)
    train_loader = torch.utils.data.DataLoader(
        data_loader,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=3,
        drop_last=False,
        pin_memory=True)

    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)
    elif args.arch == "vgg16":
        model = models.vgg16(pretrained=False, num_classes=kernel_num)
    
    model = torch.nn.DataParallel(model).cuda()
    model.train()

    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        # NOTE: The momentum here has a large impact on training; with momentum=0.99 the cross-entropy loss fails to converge.
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

    title = 'icdar2015'
    if args.pretrain:
        print('Using pretrained model.')
        assert os.path.isfile(args.pretrain), 'Error: no pretrained model file found!'
        checkpoint = torch.load(args.pretrain)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])
    elif args.resume:
        print('Resuming from checkpoint.')
        assert os.path.isfile(args.resume), 'Error: no checkpoint file found!'
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        # optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title, resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])
    images_loss = {}
    # data_plot = images_loss.values()
    # import matplotlib.pyplot as plt
    # plt.plot(data_plot)
    # plt.ylabel('Loss plot')
    # plt.show()
    for epoch in range(start_epoch, args.n_epoch):
        adjust_learning_rate(args, optimizer, epoch)
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.n_epoch, optimizer.param_groups[0]['lr']))
        
        train_loss, train_te_acc, train_te_iou = train(train_loader, images_loss, model, dice_loss, optimizer, epoch, writer)

        if epoch % 5 == 0 and epoch != 0:
            save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'lr': args.lr,
                    'optimizer': optimizer.state_dict(),
                }, checkpoint=args.checkpoint, filename='checkpoint_%d.pth' % epoch)

        logger.append([optimizer.param_groups[0]['lr'], train_loss, train_te_acc, train_te_iou])
    logger.close()
    writer.flush()
    writer.close()
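
Example #3 saves a full checkpoint every 5 epochs and, on resume, restores epoch and state_dict (the optimizer reload is commented out there). The following is a self-contained sketch of that save/resume round trip on a toy model; the model, file name, and hyperparameters are placeholders for illustration only.

# Sketch of the checkpoint save/resume round trip used by the training loops above.
# Model, path, and hyperparameters are placeholders, not the script's real values.
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)

# Save the same fields the training loop stores.
torch.save({
    'epoch': 10,
    'state_dict': model.state_dict(),
    'lr': 1e-3,
    'optimizer': optimizer.state_dict(),
}, 'toy_checkpoint.pth')

# Resume: restore weights and optimizer state, then continue from the saved epoch.
checkpoint = torch.load('toy_checkpoint.pth')
start_epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
print('resuming at epoch', start_epoch)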
Example #4
def train_psenet(config_file):
    import sys
    sys.path.append('./detection_model/PSENet')
    # sys.path.append('/home/cjy/PSENet-master')

    import torch
    import argparse
    import numpy as np
    import torch.nn as nn
    import torch.nn.functional as F
    import shutil

    from torch.autograd import Variable
    from torch.utils import data
    import os

    from dataset import IC15Loader
    from metrics import runningScore
    import models
    from util import Logger, AverageMeter
    import time
    from tensorboardX import SummaryWriter
    import util
    from yacs.config import CfgNode as CN

    writer = SummaryWriter()

    def read_config_file(config_file):
        # Load the training configuration from a YAML file via yacs.
        with open(config_file) as f:
            opt = CN.load_cfg(f)
        return opt

    args = read_config_file(config_file)

    def ohem_single(score, gt_text, training_mask):
        pos_num = int(np.sum(gt_text > 0.5)) - int(
            np.sum((gt_text > 0.5) & (training_mask <= 0.5)))

        if pos_num == 0:
            # selected_mask = gt_text.copy() * 0 # may be not good
            selected_mask = training_mask
            selected_mask = selected_mask.reshape(
                1, selected_mask.shape[0],
                selected_mask.shape[1]).astype('float32')
            return selected_mask

        neg_num = int(np.sum(gt_text <= 0.5))
        neg_num = int(min(pos_num * 3, neg_num))

        if neg_num == 0:
            selected_mask = training_mask
            selected_mask = selected_mask.reshape(
                1, selected_mask.shape[0],
                selected_mask.shape[1]).astype('float32')
            return selected_mask

        neg_score = score[gt_text <= 0.5]
        neg_score_sorted = np.sort(-neg_score)
        threshold = -neg_score_sorted[neg_num - 1]

        selected_mask = ((score >= threshold) |
                         (gt_text > 0.5)) & (training_mask > 0.5)
        selected_mask = selected_mask.reshape(
            1, selected_mask.shape[0],
            selected_mask.shape[1]).astype('float32')
        return selected_mask

    def ohem_batch(scores, gt_texts, training_masks):
        scores = scores.data.cpu().numpy()
        gt_texts = gt_texts.data.cpu().numpy()
        training_masks = training_masks.data.cpu().numpy()

        selected_masks = []
        for i in range(scores.shape[0]):
            selected_masks.append(
                ohem_single(scores[i, :, :], gt_texts[i, :, :],
                            training_masks[i, :, :]))

        selected_masks = np.concatenate(selected_masks, 0)
        selected_masks = torch.from_numpy(selected_masks).float()

        return selected_masks

    def dice_loss(input, target, mask):
        input = torch.sigmoid(input)

        input = input.contiguous().view(input.size()[0], -1)
        target = target.contiguous().view(target.size()[0], -1)
        mask = mask.contiguous().view(mask.size()[0], -1)

        input = input * mask
        target = target * mask

        a = torch.sum(input * target, 1)
        b = torch.sum(input * input, 1) + 0.001
        c = torch.sum(target * target, 1) + 0.001
        d = (2 * a) / (b + c)
        dice_loss = torch.mean(d)
        return 1 - dice_loss

    def cal_text_score(texts, gt_texts, training_masks, running_metric_text):
        training_masks = training_masks.data.cpu().numpy()
        pred_text = torch.sigmoid(texts).data.cpu().numpy() * training_masks
        pred_text[pred_text <= 0.5] = 0
        pred_text[pred_text > 0.5] = 1
        pred_text = pred_text.astype(np.int32)
        gt_text = gt_texts.data.cpu().numpy() * training_masks
        gt_text = gt_text.astype(np.int32)
        running_metric_text.update(gt_text, pred_text)
        score_text, _ = running_metric_text.get_scores()
        return score_text

    def cal_kernel_score(kernels, gt_kernels, gt_texts, training_masks,
                         running_metric_kernel):
        mask = (gt_texts * training_masks).data.cpu().numpy()
        kernel = kernels[:, -1, :, :]
        gt_kernel = gt_kernels[:, -1, :, :]
        pred_kernel = torch.sigmoid(kernel).data.cpu().numpy()
        pred_kernel[pred_kernel <= 0.5] = 0
        pred_kernel[pred_kernel > 0.5] = 1
        pred_kernel = (pred_kernel * mask).astype(np.int32)
        gt_kernel = gt_kernel.data.cpu().numpy()
        gt_kernel = (gt_kernel * mask).astype(np.int32)
        running_metric_kernel.update(gt_kernel, pred_kernel)
        score_kernel, _ = running_metric_kernel.get_scores()
        return score_kernel

    def train(train_loader, model, criterion, optimizer, epoch):
        model.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        running_metric_text = runningScore(2)
        running_metric_kernel = runningScore(2)

        end = time.time()
        for batch_idx, (imgs, gt_texts, gt_kernels,
                        training_masks) in enumerate(train_loader):
            data_time.update(time.time() - end)

            imgs = Variable(imgs.cuda())
            gt_texts = Variable(gt_texts.cuda())
            gt_kernels = Variable(gt_kernels.cuda())
            training_masks = Variable(training_masks.cuda())

            outputs = model(imgs)
            texts = outputs[:, 0, :, :]
            kernels = outputs[:, 1:, :, :]

            selected_masks = ohem_batch(texts, gt_texts, training_masks)
            selected_masks = Variable(selected_masks.cuda())

            loss_text = criterion(texts, gt_texts, selected_masks)

            loss_kernels = []
            mask0 = torch.sigmoid(texts).data.cpu().numpy()
            mask1 = training_masks.data.cpu().numpy()
            selected_masks = ((mask0 > 0.5) & (mask1 > 0.5)).astype('float32')
            selected_masks = torch.from_numpy(selected_masks).float()
            selected_masks = Variable(selected_masks.cuda())
            for i in range(6):
                kernel_i = kernels[:, i, :, :]
                gt_kernel_i = gt_kernels[:, i, :, :]
                loss_kernel_i = criterion(kernel_i, gt_kernel_i,
                                          selected_masks)
                loss_kernels.append(loss_kernel_i)
            loss_kernel = sum(loss_kernels) / len(loss_kernels)

            loss = 0.7 * loss_text + 0.3 * loss_kernel
            losses.update(loss.item(), imgs.size(0))

            if batch_idx % 100 == 0:
                writer.add_scalar('loss_text', loss_text,
                                  batch_idx + epoch * len(train_loader))
                writer.add_scalar('loss_kernel', loss_kernel,
                                  batch_idx + epoch * len(train_loader))
                writer.add_scalar('total_loss', loss,
                                  batch_idx + epoch * len(train_loader))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            score_text = cal_text_score(texts, gt_texts, training_masks,
                                        running_metric_text)
            score_kernel = cal_kernel_score(kernels, gt_kernels, gt_texts,
                                            training_masks,
                                            running_metric_kernel)

            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % 20 == 0:
                output_log = '({batch}/{size}) Batch: {bt:.3f}s | TOTAL: {total:.0f}min | ETA: {eta:.0f}min | Loss: {loss:.4f} | Acc_t: {acc: .4f} | IOU_t: {iou_t: .4f} | IOU_k: {iou_k: .4f}'.format(
                    batch=batch_idx + 1,
                    size=len(train_loader),
                    bt=batch_time.avg,
                    total=batch_time.avg * batch_idx / 60.0,
                    eta=batch_time.avg * (len(train_loader) - batch_idx) /
                    60.0,
                    loss=losses.avg,
                    acc=score_text['Mean Acc'],
                    iou_t=score_text['Mean IoU'],
                    iou_k=score_kernel['Mean IoU'])
                print(output_log)
                sys.stdout.flush()

        return (losses.avg, score_text['Mean Acc'], score_kernel['Mean Acc'],
                score_text['Mean IoU'], score_kernel['Mean IoU'])

    def adjust_learning_rate(args, optimizer, epoch):
        global state
        if epoch in args.schedule:
            args.lr = args.lr * 0.1
            for param_group in optimizer.param_groups:
                param_group['lr'] = args.lr

    def save_checkpoint(state,
                        checkpoint='checkpoint',
                        filename='_checkpoint.pth.tar',
                        epoch=0):

        filepath = os.path.join(checkpoint, 'epoch_' + str(epoch) + filename)
        torch.save(state, filepath)

    if args.checkpoint == '':
        args.checkpoint = "checkpoints/ic15_%s_bs_%d_ep_%d" % (
            args.arch, args.batch_size, args.n_epoch)
    if args.pretrain:
        if 'synth' in args.pretrain:
            args.checkpoint += "_pretrain_synth"
        else:
            args.checkpoint += "_pretrain_LSVT"

    print('checkpoint path: %s' % args.checkpoint)
    print('init lr: %.8f' % args.lr)
    print('schedule: ', args.schedule)
    sys.stdout.flush()

    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    kernel_num = 7
    min_scale = 0.4
    start_epoch = 0

    data_loader = IC15Loader(is_transform=True,
                             img_size=args.img_size,
                             kernel_num=kernel_num,
                             min_scale=min_scale)
    train_loader = torch.utils.data.DataLoader(data_loader,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=3,
                                               drop_last=True,
                                               pin_memory=True)

    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)

    model = torch.nn.DataParallel(model).cuda()

    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    title = 'icdar2015'
    if args.pretrain:
        print('Using pretrained model.')
        assert os.path.isfile(
            args.pretrain), 'Error: no pretrained model file found!'
        print(args.pretrain)
        checkpoint = torch.load(args.pretrain)
        state = model.state_dict()
        for key in state.keys():
            if key in checkpoint.keys():
                state[key] = checkpoint[key]
        model.load_state_dict(state)
        # model.load_state_dict(checkpoint['state_dict'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])
    elif args.resume:
        print('Resuming from checkpoint.')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])

    for epoch in range(start_epoch, args.n_epoch):
        adjust_learning_rate(args, optimizer, epoch)
        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.n_epoch, optimizer.param_groups[0]['lr']))

        train_loss, train_te_acc, train_ke_acc, train_te_iou, train_ke_iou = train(
            train_loader, model, dice_loss, optimizer, epoch)
        if (epoch + 1) % 5 == 0:
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'lr': args.lr,
                    'optimizer': optimizer.state_dict(),
                },
                checkpoint=args.checkpoint,
                epoch=epoch)

        logger.append([
            optimizer.param_groups[0]['lr'], train_loss, train_te_acc,
            train_te_iou
        ])
    logger.close()
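
To make the hard-negative selection in ohem_single concrete, the tiny NumPy example below applies the same logic to a 4x4 score map with a single positive pixel: negatives are ranked by score and only the hardest ones are kept, capped at 3x the number of positives. All values are made up for illustration; only the selection logic mirrors the function above.

# Tiny worked example of the OHEM selection performed by ohem_single above.
# Scores and labels are invented; only the selection logic follows the code.
import numpy as np

score = np.array([[0.9, 0.8, 0.1, 0.7],
                  [0.2, 0.6, 0.3, 0.4],
                  [0.1, 0.2, 0.5, 0.1],
                  [0.3, 0.1, 0.2, 0.55]], dtype=np.float32)
gt_text = np.zeros((4, 4), dtype=np.float32)
gt_text[0, 0] = 1.0                      # a single positive (text) pixel
training_mask = np.ones((4, 4), dtype=np.float32)

pos_num = int(np.sum(gt_text > 0.5))                      # 1 positive
neg_num = min(pos_num * 3, int(np.sum(gt_text <= 0.5)))   # keep at most 3 negatives

neg_score = score[gt_text <= 0.5]
threshold = -np.sort(-neg_score)[neg_num - 1]             # score of the 3rd-hardest negative

selected_mask = ((score >= threshold) | (gt_text > 0.5)) & (training_mask > 0.5)
print(selected_mask.astype(np.int32))
# The positive pixel plus the three highest-scoring negatives are selected for the text loss.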