Example #1
def main(args):

    #initial setup
    if args.checkpoint == '':
        args.checkpoint = "checkpoints1/ic19val_%s_bs_%d_ep_%d" % (
            args.arch, args.batch_size, args.n_epoch)
    if args.pretrain:
        if 'synth' in args.pretrain:
            args.checkpoint += "_pretrain_synth"
        else:
            args.checkpoint += "_pretrain_ic17"

    print('checkpoint path: %s' % args.checkpoint)
    print('init lr: %.8f' % args.lr)
    print('schedule: ', args.schedule)
    sys.stdout.flush()

    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    kernel_num = 7          # number of shrunk text kernels the loader generates per image
    min_scale = 0.4         # smallest kernel shrink ratio
    start_epoch = 0
    validation_split = 0.1  # hold out 10% of the data for validation
    random_seed = 42        # fixed seed so the train/validation split is reproducible
    prev_val_loss = -1      # sentinel: no validation loss recorded yet
    val_loss_list = []
    loggertf = tfLogger('./log/' + args.arch)  # TensorBoard summary logger
    #end

    #setup data loaders
    data_loader = IC19Loader(is_transform=True,
                             img_size=args.img_size,
                             kernel_num=kernel_num,
                             min_scale=min_scale)
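    # Build a reproducible 90/10 train/validation split over the dataset indices.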
    dataset_size = len(data_loader)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))
    np.random.seed(random_seed)
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_indices)
    validate_sampler = SubsetRandomSampler(val_indices)
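    # The two SubsetRandomSamplers keep the training and validation indices
    # disjoint while both DataLoaders below share the same dataset object.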

    train_loader = torch.utils.data.DataLoader(data_loader,
                                               batch_size=args.batch_size,
                                               num_workers=3,
                                               drop_last=True,
                                               pin_memory=True,
                                               sampler=train_sampler)

    validate_loader = torch.utils.data.DataLoader(data_loader,
                                                  batch_size=args.batch_size,
                                                  num_workers=3,
                                                  drop_last=True,
                                                  pin_memory=True,
                                                  sampler=validate_sampler)
    #end

    #Setup architecture and optimizer
    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resPAnet50":
        model = models.resPAnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resPAnet101":
        model = models.resPAnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resPAnet152":
        model = models.resPAnet152(pretrained=True, num_classes=kernel_num)

    model = torch.nn.DataParallel(model).cuda()

    # Prefer an optimizer defined by the model itself; otherwise fall back to SGD.
    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.99,
                                    weight_decay=5e-4)
    #end

    #options to resume/use pretrained model/train from scratch
    title = 'icdar2019MLT'
    if args.pretrain:
        print('Using pretrained model.')
        assert os.path.isfile(
            args.pretrain), 'Error: no pretrained model file found!'
        checkpoint = torch.load(args.pretrain)
        model.load_state_dict(checkpoint['state_dict'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.',
            'Validate Loss', 'Validate Acc', 'Validate IOU'
        ])
    elif args.resume:
        print('Resuming from checkpoint.')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names([
            'Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.',
            'Validate Loss', 'Validate Acc', 'Validate IOU'
        ])
    #end

    #start training model
    for epoch in range(start_epoch, args.n_epoch):
        adjust_learning_rate(args, optimizer, epoch)
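        # Update the learning rate for this epoch (decay schedule driven by args.schedule).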
        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.n_epoch, optimizer.param_groups[0]['lr']))

        train_loss, train_te_acc, train_ke_acc, train_te_iou, train_ke_iou = train(
            train_loader, model, dice_loss, optimizer, epoch, loggertf)
        val_loss, val_te_acc, val_ke_acc, val_te_iou, val_ke_iou = validate(
            validate_loader, model, dice_loss)

        #logging on tensorboard
        loggertf.scalar_summary('Training/Accuracy', train_te_acc, epoch + 1)
        loggertf.scalar_summary('Training/Loss', train_loss, epoch + 1)
        loggertf.scalar_summary('Training/IoU', train_te_iou, epoch + 1)
        loggertf.scalar_summary('Validation/Accuracy', val_te_acc, epoch + 1)
        loggertf.scalar_summary('Validation/Loss', val_loss, epoch + 1)
        loggertf.scalar_summary('Validation/IoU', val_te_iou, epoch + 1)
        #end

        #Boring Book Keeping
        print(("End of Epoch %d", epoch + 1))
        print((
            "Train Loss: {loss:.4f} | Train Acc: {acc: .4f} | Train IOU: {iou_t: .4f}"
            .format(loss=train_loss, acc=train_te_acc, iou_t=train_te_iou)))
        print((
            "Validation Loss: {loss:.4f} | Validation Acc: {acc: .4f} | Validation IOU: {iou_t: .4f}"
            .format(loss=val_loss, acc=val_te_acc, iou_t=val_te_iou)))
        #end

        #Save improving checkpoints and the best model so far
        val_loss_list.append(val_loss)
        if val_loss < prev_val_loss or prev_val_loss == -1:
            checkpointname = "{loss:.3f}".format(
                loss=val_loss) + "_epoch" + str(epoch +
                                                1) + "_checkpoint.pth.tar"
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'lr': args.lr,
                    'optimizer': optimizer.state_dict(),
                },
                checkpoint=args.checkpoint,
                filename=checkpointname)
        # val_loss was just appended to val_loss_list above, so test with <= to detect a new minimum
        if val_loss <= min(val_loss_list):
            checkpointname = "best_checkpoint.pth.tar"
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'lr': args.lr,
                    'optimizer': optimizer.state_dict(),
                },
                checkpoint=args.checkpoint,
                filename=checkpointname)
        #end
        prev_val_loss = val_loss
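        #record this epoch's learning rate and train/validation stats in log.txt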
        logger.append([
            optimizer.param_groups[0]['lr'], train_loss, train_te_acc,
            train_te_iou, val_loss, val_te_acc, val_te_iou
        ])
    #end training model
    logger.close()
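
# Hypothetical follow-up (a sketch, not part of the original script): the best
# model saved above can be restored the same way the pretrained weights are
# loaded, assuming save_checkpoint writes os.path.join(args.checkpoint, filename)
# and the model is wrapped in DataParallel so the 'module.' key prefixes match:
#   best = torch.load(os.path.join(args.checkpoint, 'best_checkpoint.pth.tar'))
#   model.load_state_dict(best['state_dict'])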
Example #2
def main(args):
    if args.checkpoint == '':
        args.checkpoint = "checkpoints/ctw1500_%s_bs_%d_ep_%d" % (
            args.arch, args.batch_size, args.n_epoch)
    if args.pretrain:
        if 'synth' in args.pretrain:
            args.checkpoint += "_pretrain_synth"
        else:
            args.checkpoint += "_pretrain_ic17"

    print('checkpoint path: %s' % args.checkpoint)
    print('init lr: %.8f' % args.lr)
    print('schedule: ', args.schedule)
    sys.stdout.flush()

    if not os.path.isdir(args.checkpoint):
        os.makedirs(args.checkpoint)

    kernel_num = 7
    min_scale = 0.4
    start_epoch = 0

    data_loader = CTW1500Loader(is_transform=True,
                                img_size=args.img_size,
                                kernel_num=kernel_num,
                                min_scale=min_scale)
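    # CTW1500 training here uses the full dataset with shuffling; unlike
    # Example #1, no validation split is made.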
    train_loader = torch.utils.data.DataLoader(data_loader,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=3,
                                               drop_last=True,
                                               pin_memory=True)

    if args.arch == "resnet50":
        model = models.resnet50(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet101":
        model = models.resnet101(pretrained=True, num_classes=kernel_num)
    elif args.arch == "resnet152":
        model = models.resnet152(pretrained=True, num_classes=kernel_num)

    model = torch.nn.DataParallel(model).cuda()

    if hasattr(model.module, 'optimizer'):
        optimizer = model.module.optimizer
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=0.99,
                                    weight_decay=5e-4)

    title = 'CTW1500'
    if args.pretrain:
        print('Using pretrained model.')
        assert os.path.isfile(
            args.pretrain), 'Error: no pretrained model file found!'
        checkpoint = torch.load(args.pretrain)
        model.load_state_dict(checkpoint['state_dict'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])
    elif args.resume:
        print('Resuming from checkpoint.')
        assert os.path.isfile(
            args.resume), 'Error: no checkpoint file found!'
        checkpoint = torch.load(args.resume)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'),
                        title=title,
                        resume=True)
    else:
        print('Training from scratch.')
        logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title=title)
        logger.set_names(
            ['Learning Rate', 'Train Loss', 'Train Acc.', 'Train IOU.'])

    for epoch in range(start_epoch, args.n_epoch):
        adjust_learning_rate(args, optimizer, epoch)
        print('\nEpoch: [%d | %d] LR: %f' %
              (epoch + 1, args.n_epoch, optimizer.param_groups[0]['lr']))

        train_loss, train_te_acc, train_ke_acc, train_te_iou, train_ke_iou = train(
            train_loader, model, dice_loss, optimizer, epoch)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'lr': args.lr,
                'optimizer': optimizer.state_dict(),
            },
            checkpoint=args.checkpoint)
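        # No filename is passed, so save_checkpoint writes its default file name
        # inside args.checkpoint after every epoch.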

        logger.append([
            optimizer.param_groups[0]['lr'], train_loss, train_te_acc,
            train_te_iou
        ])
    logger.close()