Code example #1
File: train.py  Project: zhengyussss/PAN-PSEnet
if params.pretrained != '':
    crnn.load_state_dict(torch.load(params.pretrained))
print(crnn)

# -------------------------------------------------------------------------------------------------
converter = utils.strLabelConverter(alpha)
criterion = CTCLoss()

# placeholder tensors, refilled with real data at each batch
image = torch.FloatTensor(params.batchSize, 3, params.imgH, params.imgH)
text = torch.IntTensor(params.batchSize * 5)
length = torch.IntTensor(params.batchSize)
if params.cuda and torch.cuda.is_available():
    crnn.cuda()
    if params.multi_gpu:
        crnn = torch.nn.DataParallel(crnn, device_ids=range(params.ngpu))
    image = image.cuda()
    criterion = criterion.cuda()
image = Variable(image)
text = Variable(text)
length = Variable(length)

# loss averager
loss_avg = utils.averager()

# setup optimizer
if params.adam:
    optimizer = optim.Adam(crnn.parameters(),
                           lr=params.lr,
                           betas=(params.beta1, 0.999))
elif params.adadelta:
    optimizer = optim.Adadelta(crnn.parameters())
else:
    optimizer = optim.RMSprop(crnn.parameters(), lr=params.lr)
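
Most examples on this page call the CTC criterion as criterion(preds, text, preds_size, length). With torch.nn.CTCLoss that call takes log-probabilities of shape (T, N, C) plus per-sample input and target lengths, which is why the snippets keep image/text/length placeholder tensors and refill them every batch. A minimal self-contained sketch of the shape contract (all sizes and names here are illustrative, not taken from the project above):

import torch
import torch.nn.functional as F
from torch.nn import CTCLoss

T, N, C = 26, 4, 37                       # time steps, batch size, classes (blank = index 0)
criterion = CTCLoss(blank=0)

logits = torch.randn(T, N, C)             # raw network output
log_probs = F.log_softmax(logits, dim=2)  # CTCLoss expects log-probabilities

targets = torch.randint(1, C, (N, 10), dtype=torch.long)  # label indices, never the blank
input_lengths = torch.full((N,), T, dtype=torch.long)     # one sequence length per sample
target_lengths = torch.full((N,), 10, dtype=torch.long)

loss = criterion(log_probs, targets, input_lengths, target_lengths)
print(loss.item())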
Code example #2
def run(args):
    model = CRNN(image_height=args.image_height,
                 num_of_channels=args.num_of_channels,
                 num_of_classes=args.num_of_classes,
                 num_of_lstm_hidden_units=args.num_of_lstm_hidden_units)

    if torch.cuda.is_available() and not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    optimizer_state = None
    if args.pretrained != '':
        print('loading pretrained model from %s' % args.pretrained)
        optimizer_state = load_from_snapshot(args, model=model)
    else:
        model.apply(weights_init)
    print(model)

    trainer = Trainer()
    criterion = CTCLoss(zero_infinity=True, reduction='mean')

    train_image = torch.FloatTensor(args.batch_size, 3, args.image_height, 512)
    test_image = torch.FloatTensor(args.test_batch_size, 3, args.image_height, 512)

    if args.cuda:
        model.cuda()
        model = torch.nn.DataParallel(model, device_ids=range(args.ngpu))
        train_image = train_image.cuda()
        test_image = test_image.cuda()
        criterion = criterion.cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    if optimizer_state is not None:
        optimizer.load_state_dict(optimizer_state)

    train_image = Variable(train_image)
    test_image = Variable(test_image)

    val_dataset = WordsDataset(min_page_index=600,
                               max_page_index=769,
                               data_set_dir=args.train_dataset_dir,
                               transform=transforms.Compose([ToFloatTensor()]))

    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=args.test_batch_size,
                            shuffle=False,
                            num_workers=1)

    train_dataset = WordsDataset(min_page_index=0,
                                 max_page_index=600,
                                 data_set_dir=args.train_dataset_dir,
                                 transform=transforms.Compose([
                                     Erode(),
                                     Rotate(),
                                     ApplyAveraging(),
                                     GaussNoise(),
                                     ToFloatTensor()
                                 ]))

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=1)

    test_dataset = TestWordsDataset(data_set_path=args.test_dataset_path,
                                    transform=ToFloatTensor())

    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=args.test_batch_size,
                             shuffle=False,
                             num_workers=1)

    for epoch in range(1, args.epochs + 1):
        if epoch % 5 == 0:
            val_loss, val_accuracy = trainer.test(criterion=criterion,
                                                  model=model,
                                                  test_loader=val_loader,
                                                  test_image=test_image)

            print('\nValidation set: Average loss: {:.4f}, Accuracy: {:.4f}\n'.format(
                val_loss, val_accuracy))

            val_losses_path = os.path.join(args.val_loss, 'losses.npy')
            try:
                val_losses_file = list(np.load(val_losses_path))
                np.save(val_losses_path, np.asarray(val_losses_file + [val_loss]))
            except FileNotFoundError:
                np.save(val_losses_path, np.asarray([val_loss]))

            test_loss, test_accuracy = trainer.test(criterion=criterion,
                                                    model=model,
                                                    test_loader=test_loader,
                                                    test_image=test_image)

            print('\nTest set: Average loss: {:.4f}, Accuracy: {:.4f}\n'.format(
                test_loss, test_accuracy))

            test_losses_path = os.path.join(args.test_loss, 'losses.npy')
            try:
                test_losses_file = list(np.load(test_losses_path))
                np.save(test_losses_path, np.asarray(test_losses_file + [test_loss]))
            except FileNotFoundError:
                np.save(test_losses_path, np.asarray([test_loss]))

        train_losses = trainer.train(args=args,
                                     criterion=criterion,
                                     model=model,
                                     train_loader=train_loader,
                                     optimizer=optimizer,
                                     epoch=epoch,
                                     train_image=train_image)

        train_losses_path = os.path.join(args.train_loss, 'losses.npy')
        try:
            train_losses_file = list(np.load(train_losses_path))
            np.save(train_losses_path, np.asarray(train_losses_file + train_losses))
        except FileNotFoundError:
            np.save(train_losses_path, np.asarray(train_losses))

        if args.save_model != '':
            state = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
            torch.save(state, os.path.join(args.save_model, 'crnn' + str(epoch) + '.pt'))
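
The per-epoch checkpoint above stores the model and optimizer state dicts side by side, and run() restores the optimizer state when --pretrained is given. The load_from_snapshot helper itself is not part of this listing; a sketch of what such a loader presumably looks like (an assumption, not the project's actual code):

import torch

def load_from_snapshot(path, model):
    # Checkpoint layout matching the save above: {'state_dict': ..., 'optimizer': ...}
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    return checkpoint['optimizer']  # caller passes this to optimizer.load_state_dict()

# usage sketch: restore weights first, build the optimizer, then restore its state
# optimizer_state = load_from_snapshot('crnn10.pt', model)
# optimizer = optim.Adam(model.parameters(), lr=1e-4)
# optimizer.load_state_dict(optimizer_state)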
Code example #3
def test_grcnn(config_yaml):
    import sys
    sys.path.append('./recognition_model/GRCNN')

    import random
    import torch.backends.cudnn as cudnn
    import torch.optim as optim
    import torch.utils.data
    import numpy as np
    import Levenshtein
    from torch.autograd import Variable
    # from warpctc_pytorch import CTCLoss
    # from GRCNN.utils.Logger import Logger
    from torch.nn import CTCLoss
    import GRCNN.utils.keys as keys
    import GRCNN.utils.util as util
    import dataset
    import GRCNN.models.crann as crann
    import yaml
    import os
    import time

    def adjust_lr(optimizer, base_lr, epoch, step):
        lr = base_lr * (0.1**(epoch // step))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def train(model, train_loader, val_loader, criterion, optimizer, opt,
              converter, epoch):
        # Set up training phase.
        interval = int(len(train_loader) / opt['SAVE_FREQ'])
        model.train()

        for i, (cpu_images, cpu_gt) in enumerate(train_loader, 1):
            # print('iter {} ...'.format(i))
            bsz = cpu_images.size(0)
            text, text_len = converter.encode(cpu_gt)
            v_images = Variable(cpu_images.cuda())
            v_gt = Variable(text)
            v_gt_len = Variable(text_len)

            model = model.cuda()
            predict = model(v_images)
            predict_len = Variable(torch.IntTensor([predict.size(0)] * bsz))

            loss = criterion(predict, v_gt, predict_len, v_gt_len)
            # logger.scalar_summary('train_loss', loss.data[0], i + epoch * len(train_loader))

            # Compute accuracy
            _, acc = predict.max(2)
            acc = acc.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(acc.data, predict_len.data, raw=False)
            n_correct = 0
            for pred, target in zip(sim_preds, cpu_gt):
                if pred.lower() == target.lower():
                    n_correct += 1
            accuracy = n_correct / float(bsz)

            # logger.scalar_summary('train_accuray', accuracy, i + epoch * len(train_loader))

            # Backpropagate
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % interval == 0 and i > 0:
                print('Training @ Epoch: [{0}][{1}/{2}]; Train Accuracy:{3}'.
                      format(epoch, i, len(train_loader), accuracy))
                val(model, val_loader, criterion, converter, epoch,
                    i + epoch * len(train_loader), False)
                model.train()
                freq = int(i / interval)
                save_checkpoint(
                    {
                        'epoch': epoch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }, '{0}/crann_{1}_{2}.pth'.format(opt['SAVE_PATH'], epoch,
                                                      freq))

    def val(model, ds_loader, criterion, converter, epoch, iteration, valonly):
        print('Start validating on epoch:{0}/iter:{1}...'.format(
            epoch, iteration))
        model.eval()
        ave_loss = 0.0
        ave_accuracy = 0.0
        err_sim = []
        err_gt = []
        distance = 0
        length = 0
        with torch.no_grad():
            for i, (cpu_images, cpu_gt) in enumerate(ds_loader):
                bsz = cpu_images.size(0)
                text, text_len = converter.encode(cpu_gt)
                v_Images = Variable(cpu_images.cuda())
                v_gt = Variable(text)
                v_gt_len = Variable(text_len)

                predict = model(v_Images)
                predict_len = Variable(torch.IntTensor([predict.size(0)] *
                                                       bsz))
                loss = criterion(predict, v_gt, predict_len, v_gt_len)
                ave_loss += loss.item()  # loss.data[0] breaks on 0-dim tensors in modern PyTorch

                # Compute accuracy
                _, acc = predict.max(2)
                acc = acc.transpose(1, 0).contiguous().view(-1)
                sim_preds = converter.decode(acc.data,
                                             predict_len.data,
                                             raw=False)
                n_correct = 0
                for pred, target in zip(sim_preds, cpu_gt):
                    length += len(target)
                    if pred.lower() == target.lower():
                        n_correct += 1.0
                    else:
                        err_sim.append(pred)
                        err_gt.append(target)
                ave_accuracy += n_correct / float(bsz)
            for pred, gt in zip(err_sim, err_gt):
                print('pred: %-20s, gt: %-20s' % (pred, gt))
                distance += Levenshtein.distance(pred, gt)
            # print("The Levenshtein distance is:",distance)
            print("The average Levenshtein distance is:", distance / length)
            if not valonly:
                pass
                # logger.scalar_summary('validation_loss', ave_loss / len(ds_loader), iteration)
                #logger.scalar_summary('validation_accuracy', ave_accuracy / len(ds_loader), iteration)
                # logger.scalar_summary('Ave_Levenshtein_distance', distance / length, iteration)
            print(
                'Testing Accuracy:{0}, Testing Loss:{1} @ Epoch{2}, Iteration{3}'
                .format(ave_accuracy / len(ds_loader),
                        ave_loss / len(ds_loader), epoch, iteration))

    def save_checkpoint(state, file_name):
        # time.sleep(0.01)
        # torch.save(state, file_name)
        try:
            time.sleep(0.01)
            torch.save(state, file_name)
        except RuntimeError:
            print("RuntimeError")
            pass

    '''
    Training/Finetune CNN_RNN_Attention Model.
    '''
    #### Load config settings. ####
    f = open(config_yaml, encoding='utf-8')
    opt = yaml.safe_load(f)
    # if os.path.isdir(opt['LOGGER_PATH']) == False:
    #     os.mkdir(opt['LOGGER_PATH'])
    # logger = Logger(opt['LOGGER_PATH'])
    if not os.path.isdir(opt['SAVE_PATH']):
        os.makedirs(opt['SAVE_PATH'])
    manualSeed = random.randint(1, 10000)
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    cudnn.benchmark = True

    #### Set up DataLoader. ####
    train_cfg = opt['TRAIN']
    ds_cfg = train_cfg['DATA_SOURCE']
    print('Building up dataset:{}'.format(ds_cfg['TYPE']))
    if ds_cfg['TYPE'] == 'SYN_DATA':
        text_gen = util.TextGenerator(ds_cfg['GEN_SET'], ds_cfg['GEN_LEN'])
        ds_train = dataset.synthDataset(ds_cfg['FONT_ROOT'],
                                        ds_cfg['FONT_SIZE'], text_gen)
    elif ds_cfg['TYPE'] == 'IMG_DATA':
        ds_train = dataset.trainDataset(
            ds_cfg['IMG_ROOT'], ds_cfg['TRAIN_SET'],
            transform=None)  # dataset.graybackNormalize()
    assert ds_train
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=train_cfg['BATCH_SIZE'],
        shuffle=True,
        sampler=None,
        num_workers=opt['WORKERS'],
        collate_fn=dataset.alignCollate(imgH=train_cfg['IMG_H'],
                                        imgW=train_cfg['MAX_W']))

    val_cfg = opt['VALIDATION']
    ds_val = dataset.testDataset(val_cfg['IMG_ROOT'],
                                 val_cfg['VAL_SET'],
                                 transform=None)  # dataset.graybackNormalize()
    assert ds_val
    val_loader = torch.utils.data.DataLoader(ds_val,
                                             batch_size=32,
                                             shuffle=False,
                                             num_workers=opt['WORKERS'],
                                             collate_fn=dataset.alignCollate(
                                                 imgH=train_cfg['IMG_H'],
                                                 imgW=train_cfg['MAX_W']))

    #### Model construction and Initialization. ####
    alphabet = keys.alphabet
    nClass = len(alphabet) + 1

    if opt['N_GPU'] > 1:
        opt['RNN']['multi_gpu'] = True
    else:
        opt['RNN']['multi_gpu'] = False
    model = crann.CRANN(opt, nClass)
    # print(model)

    #### Train/Val the model. ####
    converter = util.strLabelConverter(alphabet)
    # from warpctc_pytorch import CTCLoss
    criterion = CTCLoss()
    if opt['CUDA']:
        model.cuda()
        criterion.cuda()

    if opt['OPTIMIZER'] == 'RMSprop':
        optimizer = optim.RMSprop(model.parameters(), lr=opt['TRAIN']['LR'])
    elif opt['OPTIMIZER'] == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=opt['TRAIN']['LR'],
                               betas=(opt['TRAIN']['BETA1'], 0.999))
    elif opt['OPTIMIZER'] == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=opt['TRAIN']['LR'])
    else:
        optimizer = optim.Adadelta(model.parameters(), lr=opt['TRAIN']['LR'])

    start_epoch = 0
    assert opt['VAL_ONLY'] == True, "You should set the variable 'VAL_ONLY' to True"
    if opt['VAL_ONLY']:
        print('=>loading pretrained model from %s for val only.' %
              opt['CRANN'])
        checkpoint = torch.load(opt['CRANN'])
        model.load_state_dict(checkpoint['state_dict'])
        val(model, val_loader, criterion, converter, 0, 0, True)
    elif opt['FINETUNE']:
        print('=>loading pretrained model from %s for finetune.' %
              opt['CRANN'])
        checkpoint = torch.load(opt['CRANN'])
        # model.load_state_dict(checkpoint['state_dict'])
        model_dict = model.state_dict()
        # print(model_dict.keys())
        cnn_dict = {
            "cnn." + k: v
            for k, v in checkpoint.items() if "cnn." + k in model_dict
        }
        model_dict.update(cnn_dict)
        model.load_state_dict(model_dict)
        for epoch in range(start_epoch, opt['EPOCHS']):
            adjust_lr(optimizer, opt['TRAIN']['LR'], epoch, opt['STEP'])
            train(model, train_loader, val_loader, criterion, optimizer, opt,
                  converter, epoch)
    elif opt['RESUME']:
        print('=>loading checkpoint from %s for resume training.' %
              opt['CRANN'])
        checkpoint = torch.load(opt['CRANN'])
        start_epoch = checkpoint['epoch'] + 1
        print('resume from epoch:{}'.format(start_epoch))
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        for epoch in range(start_epoch, opt['EPOCHS']):
            adjust_lr(optimizer, opt['TRAIN']['LR'], epoch, opt['STEP'])
            train(model, train_loader, val_loader, criterion, optimizer, opt,
                  converter, epoch)
    else:
        print('train from scratch.')
        for epoch in range(start_epoch, opt['EPOCHS']):
            adjust_lr(optimizer, opt['TRAIN']['LR'], epoch, opt['STEP'])
            train(model, train_loader, val_loader, criterion, optimizer, opt,
                  converter, epoch)
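
Both train and val above decode by taking the per-step argmax and handing it to converter.decode, which collapses repeated labels and strips blanks (best-path decoding). A standalone version of that step, assuming blank index 0 and an alphabet string whose first character maps to class 1 (illustrative, not GRCNN's own converter):

import torch

def greedy_ctc_decode(log_probs, alphabet, blank=0):
    # log_probs: (T, N, C). Argmax per time step, collapse repeats, drop blanks.
    best_path = log_probs.argmax(2)  # (T, N)
    results = []
    for n in range(best_path.size(1)):
        prev, chars = blank, []
        for idx in best_path[:, n].tolist():
            if idx != blank and idx != prev:
                chars.append(alphabet[idx - 1])  # alphabet excludes the blank
            prev = idx
        results.append(''.join(chars))
    return results

# greedy_ctc_decode(torch.randn(26, 2, 11).log_softmax(2), '0123456789')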
Code example #4
File: base_train.py  Project: Sanny26/indic-htr
class BaseHTR(object):
    def __init__(self, opt, dataset_name='iam', reset_log=False):
        self.opt = opt
        self.mode = self.opt.mode
        self.dataset_name = dataset_name
        self.stn_nc = self.opt.stn_nc
        self.cnn_nc = self.opt.cnn_nc
        self.nheads = self.opt.nheads
        self.criterion = CTCLoss(blank=0, reduction='sum', zero_infinity=True)
        self.label_transform = self.init_label_transform()
        self.test_transforms = self.init_test_transforms()
        self.train_transforms = self.init_train_transforms()
        self.val1_iter = self.opt.val1_iter # Number of train data batches that will be validated
        self.val2_iter = self.opt.val2_iter # Number of validation data batches that will be validated
        self.stn_attn = None
        self.val_metric = 'cer'
        self.use_loc_bn = False
        self.CNN = 'ResCRNN'
        self.loc_block = 'LocNet'
        self.identity_matrix = torch.tensor([1, 0, 0, 0, 1, 0],
                                       dtype=torch.float).cuda()
        if self.mode == 'train':
            if len(self.opt.trainRoot) == 0:
                self.train_root = "/ssd_scratch/cvit/santhoshini/{}-train-lmdb".format(self.dataset_name)
            else:
                self.train_root = self.opt.trainRoot
        if len(self.opt.valRoot) == 0:
            self.test_root = "/ssd_scratch/cvit/santhoshini/{}-test-lmdb".format(self.dataset_name)
        else:
            self.test_root = self.opt.valRoot

        if not os.path.exists(self.opt.node_dir):
            os.makedirs(self.opt.node_dir)
        elif reset_log:
            shutil.rmtree(self.opt.node_dir)
            os.makedirs(self.opt.node_dir)

        random.seed(self.opt.manualSeed)
        np.random.seed(self.opt.manualSeed)
        torch.manual_seed(self.opt.manualSeed)

        # cudnn.benchmark = True
        cudnn.deterministic = True
        cudnn.benchmark = False
        cudnn.enabled = True
        # print('CudNN enabled', cudnn.enabled)

        if torch.cuda.is_available() and not self.opt.cuda:
            print("WARNING: You have a CUDA device, so you should probably run with --cuda")
        else:
            self.opt.gpu_id = list(map(int, self.opt.gpu_id.split(',')))
            torch.cuda.set_device(self.opt.gpu_id[0])


    def run(self):
        if self.mode == "train":
            # print(self.train_root, self.test_root)
            self.train_data, self.train_loader = self.get_data_loader(self.train_root,
                                                                      self.train_transforms,
                                                                      self.label_transform)
            self.test_data, self.test_loader = self.get_data_loader(self.test_root,
                                                                      self.test_transforms,
                                                                      self.label_transform)
            self.converter = utils.strLabelConverter(self.test_data.id2char,
                                                     self.test_data.char2id,
                                                     self.test_data.ctc_blank)
            check_data(self.train_loader, '{}train'.format(self.dataset_name))
            check_data(self.test_loader, '{}val'.format(self.dataset_name))
            # pdb.set_trace()
            self.nclass = self.test_data.rec_num_classes
            self.model, self.parameters = self.get_model()
            self.init_variables()
            self.init_train_params()
            print('Classes: ', self.test_data.voc)
            print('#Train Samples: ', self.train_data.nSamples)
            print('#Val Samples: ', self.test_data.nSamples)
            self.train()
        elif self.mode == "test":
            self.test_data, self.test_loader = self.get_data_loader(self.test_root,
                                                                      self.test_transforms,
                                                                      self.label_transform)
            self.converter = utils.strLabelConverter(self.test_data.id2char,
                                                     self.test_data.char2id,
                                                     self.test_data.ctc_blank)
            check_data(self.test_loader, '{}test'.format(self.dataset_name))
            self.nclass = self.test_data.rec_num_classes
            self.model, self.parameters = self.get_model()
            self.init_variables()
            print('Classes: ', self.test_data.voc)
            print('#Test Samples: ', self.test_data.nSamples)
            self.eval(self.test_data)

    def init_train_transforms(self):
        T = Compose([Rescale((self.opt.imgH, self.opt.imgW)),ElasticTransformation(0.7),ToTensor()])
        return T

    def init_test_transforms(self):
        T = Compose([Rescale((self.opt.imgH, self.opt.imgW)),ToTensor()])
        return T

    def init_label_transform(self):
        T = None
        return T

    def init_variables(self):
        self.image = torch.FloatTensor(self.opt.batchSize, 3, self.opt.imgH, self.opt.imgH)
        self.text = torch.LongTensor(self.opt.batchSize * 5)
        self.length = torch.LongTensor(self.opt.batchSize)
        if self.opt.cuda:
            self.image = self.image.cuda()
            self.criterion = self.criterion.cuda()
            self.text = self.text.cuda()
            self.length = self.length.cuda()
        self.image = Variable(self.image)
        self.text = Variable(self.text)
        self.length = Variable(self.length)

    def init_train_params(self):
        if self.opt.adam:
            self.optimizer = optim.Adam(self.parameters, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
        elif self.opt.adadelta:
            self.optimizer = optim.Adadelta(self.parameters, lr=self.opt.lr)
        elif self.opt.rmsprop:
            self.optimizer = optim.RMSprop(self.parameters, lr=self.opt.lr)
        else:
            self.optimizer = optim.SGD(self.parameters, lr=self.opt.lr, momentum=self.opt.momentum)

        if self.opt.StepLR:
            self.scheduler = StepLR(self.optimizer, step_size=20000, gamma=0.5)
        else:
            self.scheduler = None
        # scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.00001, max_lr=0.001,
        #                                             cycle_momentum=False)
        print(self.optimizer)
        return

    def get_model(self):
        crnn = ModelBuilder(self.opt.imgH, self.opt.imgW, self.opt.tps_inputsize,
                        self.opt.tps_outputsize, self.opt.num_control_points, self.opt.tps_margins, self.opt.stn_activation,
                        self.opt.nh, self.stn_nc, self.cnn_nc, self.nclass, STN_type=self.opt.STN_type,
                        nheads=self.nheads, stn_attn=self.stn_attn, use_loc_bn=self.use_loc_bn, loc_block = self.loc_block,
                        CNN=self.CNN)
        if self.opt.cuda:
            crnn.cuda()
            crnn = torch.nn.DataParallel(crnn, device_ids=self.opt.gpu_id, dim=1)
        else:
            crnn = torch.nn.DataParallel(crnn, device_ids=self.opt.gpu_id)
        if self.opt.pretrained != '':
            if self.opt.transfer:
                d_params = crnn.state_dict()
                s_params = torch.load(self.opt.pretrained)
                for name1 in s_params:
                    param1 = s_params[name1]
                    try:
                        d_params[name1].data.copy_(param1.data)
                    except Exception:
                        print('Skipping weight ', name1)
                        continue
                crnn.load_state_dict(d_params)
            else:
                print('Using pretrained model', self.opt.pretrained)
                crnn.load_state_dict(torch.load(self.opt.pretrained))
        else:
            crnn.apply(weights_init)
        return crnn, crnn.parameters()

    def get_data_loader(self, root, im_transforms, label_transforms, num_samples=np.inf):
        data = dataset.lmdbDataset(root=root, voc=self.opt.alphabet, num_samples=num_samples,
                                   transform=im_transforms, label_transform=label_transforms,
                                   voc_type=self.opt.alphabet_type, lowercase=self.opt.lowercase,
                                   alphanumeric=self.opt.alphanumeric, return_list=True)
        if not self.opt.random_sample:
            sampler = dataset.randomSequentialSampler(data, self.opt.batchSize)
        else:
            sampler = None
        data_loader = torch.utils.data.DataLoader(data, batch_size=self.opt.batchSize,
                                                shuffle=(sampler is None),  # DataLoader rejects shuffle=True together with a sampler
                                                sampler=sampler,
                                                num_workers=int(self.opt.workers),
                                                collate_fn=dataset.collatedict())
        return data, data_loader

    def train(self, max_iter=np.inf):
        loss_avg = utils.averager()
        prev_cer = 100
        prev_wer = 100
        write_info(self.model, self.opt)
        self.writer = Writer(self.opt.lr, self.opt.nepoch, self.opt.node_dir, use_tb=self.opt.use_tb)
        self.iterations = 0
        for epoch in range(self.opt.nepoch):
            self.writer.epoch = epoch
            self.writer.nbatches = len(self.train_loader)
            self.train_iter = iter(self.train_loader)
            i = 0
            while i < len(self.train_loader):
                if self.iterations % self.opt.valInterval == 0:
                    valloss, val_CER, val_WER = self.eval(self.test_data, max_iter=self.val2_iter)
                    self.writer.update_valloss(valloss.val().item(), val_CER)
                    # trloss, trER = self.eval(self.train_data, max_iter=self.val1_iter)
                    # self.writer.update_trloss2(trloss.val().item(), trER)
                    torch.save(
                            self.model.state_dict(), '{0}/{1}.pth'.format(self.opt.node_dir,'latest'))
                    if val_CER < prev_cer:
                        torch.save(
                            self.model.state_dict(), '{0}/{1}.pth'.format(self.opt.node_dir,'best_cer'))
                        prev_cer = val_CER
                        self.writer.update_best_er(val_CER, self.iterations)
                    if val_WER < prev_wer:
                        torch.save(
                            self.model.state_dict(), '{0}/{1}.pth'.format(self.opt.node_dir,'best_wer'))
                        prev_wer = val_WER
                        # self.writer.update_best_er(val_WER, self.iterations)
                cost = self.trainBatch()
                loss_avg.add(cost)
                self.iterations += 1
                i += 1
                self.writer.iterations = self.iterations
                self.writer.batch = i

                if self.iterations % self.opt.displayInterval == 0:
                    self.writer.update_trloss(loss_avg.val().item())
                    loss_avg.reset()
        self.writer.end()
        return

    def forward_sample(self, data):
        cpu_images, cpu_texts = data
        utils.loadData(self.image, cpu_images)
        t, l = self.converter.encode(cpu_texts)
        utils.loadData(self.text, t)
        utils.loadData(self.length, l)
        output_dict = self.model(self.image)
        batch_size = cpu_images.size(0)
        output_dict['batch_size'] = batch_size
        output_dict['gt'] = cpu_texts
        return output_dict

    def get_loss(self, data):
        preds = data['preds']
        batch_size = data['batch_size']
        preds_size = data['preds_size']
        torch.backends.cudnn.enabled = False
        cost = self.criterion(preds, self.text, preds_size, self.length) / batch_size
        torch.backends.cudnn.enabled = True
        return cost

    def decoder(self, preds, preds_size):
        if self.opt.beamdecoder:
            sim_preds = []
            for j in range(preds.size()[1]):
                probs = preds[:, j, :]
                probs = torch.cat([probs[:, 1:], probs[:, 0].unsqueeze(1)], dim=1).cpu().detach().numpy()
                sim_preds.append(ctc_bs.ctcBeamSearch(probs, self.test_data.voc, None))
        else:
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = self.converter.decode(preds.data, preds_size.data, raw=False)
        return sim_preds

    def eval(self, data, max_iter=np.inf):
        data_loader = torch.utils.data.DataLoader(data, batch_size=self.opt.batchSize,
                                                num_workers=int(self.opt.workers),
                                                pin_memory=True,
                                                collate_fn=dataset.collatedict())
        self.model.eval()
        gts = []
        decoded_preds = []
        val_iter = iter(data_loader)
        tc = 0
        wc = 0
        ww = 0
        tw = 0
        loss_avg = utils.averager()
        max_iter = min(max_iter, len(data_loader))
        with torch.no_grad():
            # print('-------Current LR-----')
            # for param_group in self.optimizer.param_groups:
            #     print(param_group['lr'])
            # print('---------------------')
            for i in range(max_iter):
                if self.opt.mode == 'test':
                    print('%d / %d' % (i, len(data_loader)), end='\r')
                output_dict = self.forward_sample(next(val_iter))
                batch_size = output_dict['batch_size']
                preds = F.log_softmax(output_dict['probs'], 2)
                preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
                cost = self.get_loss({'preds': preds, 'batch_size': batch_size,
                                      'preds_size': preds_size, 'params':output_dict['params']})
                loss_avg.add(cost)
                decoded_pred = self.decoder(preds, preds_size)
                gts += list(output_dict['gt'])
                decoded_preds += list(decoded_pred)

        if self.mode == "train":
            pcounter = 0
            for target, pred in zip(gts, decoded_preds):
                if pcounter < 5:
                    print('Gt:   ', target)
                    print('Pred: ', pred)
                    pcounter += 1
                if target!=pred:
                    ww += 1
                tw += 1
                wc += utils.levenshtein(target, pred)
                tc += len(target)
            wer = (ww / tw)*100
            cer = (wc / tc)*100
            return loss_avg, cer, wer
        else:
            with open(self.opt.out, 'w') as f:
                for target, pred in zip(gts, decoded_preds):
                    f.write('{}\n{}\n'.format(pred, target))
            print('Generated predictions for {} samples'.format(self.test_data.nSamples))
        return

    def trainBatch(self):
        self.model.train()
        output_dict = self.forward_sample(next(self.train_iter))
        batch_size = output_dict['batch_size']
        preds = F.log_softmax(output_dict['probs'], 2)
        preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
        cost = self.get_loss({'preds': preds, 'batch_size':batch_size, 'preds_size':preds_size, 'params':output_dict['params']})
        if torch.isnan(cost):
            pdb.set_trace()
        self.model.zero_grad()
        cost.backward()

        # grad_zero_flag = 0
        # for name, param in self.model.named_parameters():
        #     if param.grad.sum() == 0:
        #         if name.find('bias')<0:
        #             print(name)
        #             grad_zero_flag = 1
        # if grad_zero_flag:
        #     print(f'---------{self.iterations} training-------------')

        # pdb.set_trace()
        # try:
        #     tparams = torch.stack(output_dict['params'], axis=1)
        # except:
        #     tparams = output_dict['params']
        # if torch.all(tparams[0] == tparams.mean(axis=0)):
        #     if self.iterations != 0:
        #         pdb.set_trace()

        self.optimizer.step()
        if self.scheduler:
            self.scheduler.step()

        return cost
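
The CER/WER figures in eval above rest on utils.levenshtein, which is not shown in this listing. A standard dynamic-programming edit distance consistent with how it is used here (two sequences in, an integer out; an assumed implementation):

def levenshtein(a, b):
    # Classic O(len(a) * len(b)) edit-distance DP with two rolling rows.
    if len(a) < len(b):
        a, b = b, a
    previous = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        current = [i]
        for j, cb in enumerate(b, 1):
            current.append(min(previous[j] + 1,                # deletion
                               current[j - 1] + 1,             # insertion
                               previous[j - 1] + (ca != cb)))  # substitution
        previous = current
    return previous[-1]

# CER as computed above: summed edit distance over summed reference length, times 100
# cer = 100 * sum(levenshtein(gt, pred) for gt, pred in pairs) / sum(len(gt) for gt, _ in pairs)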
Code example #5
File: train.py  Project: tomowang/id-card-ocr
def train(field):
    alphabet = ''.join(json.load(open('./cn-alphabet.json', 'rb')))
    nclass = len(alphabet) + 1  # +1 for the CTC blank symbol ('-')
    batch_size = BATCH_SIZE
    if field == 'address' or field == 'psb':
        batch_size = 1  # image length varies

    converter = LabelConverter(alphabet)
    criterion = CTCLoss(zero_infinity=True)

    crnn = CRNN(IMAGE_HEIGHT, nc, nclass, number_hidden)
    crnn.apply(weights_init)

    image_transform = transforms.Compose([
        Rescale(IMAGE_HEIGHT),
        transforms.ToTensor(),
        Normalize()
    ])

    dataset = LmdbDataset(db_path, field, image_transform)
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            shuffle=True, num_workers=4)

    image = torch.FloatTensor(batch_size, 3, IMAGE_HEIGHT, IMAGE_HEIGHT)
    text = torch.IntTensor(batch_size * 5)
    length = torch.IntTensor(batch_size)

    image = Variable(image)
    text = Variable(text)
    length = Variable(length)

    loss_avg = utils.averager()
    optimizer = optim.RMSprop(crnn.parameters(), lr=lr)

    if torch.cuda.is_available():
        crnn.cuda()
        crnn = nn.DataParallel(crnn)
        image = image.cuda()
        criterion = criterion.cuda()

    def train_batch(net, iteration):
        data = next(iteration)
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        utils.load_data(image, cpu_images)
        t, l = converter.encode(cpu_texts)
        utils.load_data(text, t)
        utils.load_data(length, l)

        preds = crnn(image)
        preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        crnn.zero_grad()
        cost.backward()
        optimizer.step()
        return cost

    nepoch = 25
    for epoch in range(nepoch):
        train_iter = iter(dataloader)
        i = 0
        while i < len(dataloader):
            for p in crnn.parameters():
                p.requires_grad = True
            crnn.train()

            cost = train_batch(crnn, train_iter)
            loss_avg.add(cost)
            i += 1

            if i % 500 == 0:
                print('%s [%d/%d][%d/%d] Loss: %f' %
                        (datetime.datetime.now(), epoch, nepoch, i, len(dataloader), loss_avg.val()))
                loss_avg.reset()

            # do checkpointing
            if i % 500 == 0:
                torch.save(
                    crnn.state_dict(), f'{model_path}crnn_{field}_{epoch}_{i}.pth')
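
This example drops to batch_size = 1 for the 'address' and 'psb' fields because their image widths vary; other examples on this page handle the same problem with an aligning collate function (dataset.alignCollate) that brings every image in a batch to a common width. A minimal pad-to-widest collate in that spirit (an illustration, not the project's code):

import torch
import torch.nn.functional as F

def pad_collate(batch):
    # batch: list of ((C, H, W) image tensor, text) pairs with varying W.
    images, texts = zip(*batch)
    max_w = max(img.size(2) for img in images)
    padded = [F.pad(img, (0, max_w - img.size(2))) for img in images]  # right-pad the width
    return torch.stack(padded, 0), list(texts)

# dataloader = DataLoader(dataset, batch_size=8, collate_fn=pad_collate)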
Code example #6
            'Epoch[%d/%d] lr = %f \n Avg Training Loss: %f  Avg Validation loss: %f \n Avg CER: %f  Avg WER: %f'
            % (epoch + 1, params.epochs, optimizer.param_groups[0]['lr'],
               avg_cost, val_loss, val_CER, val_WER))

    print("Training done.")
    return losses


# -----------------------------------------------
"""
In this block
    criterion define
"""
CRITERION = CTCLoss()
if params.cuda and torch.cuda.is_available():
    CRITERION = CRITERION.cuda()

# -----------------------------------------------

if __name__ == "__main__":
    torch.cuda.empty_cache()
    # Initialize model
    MODEL = net_init()
    # print(MODEL)
    if params.cuda and torch.cuda.is_available():
        MODEL = MODEL.cuda()

    # Initialize optimizer
    assert params.optimizer in [
        'adadelta', 'adam', 'rmsprop', 'sgd'
    ], "Unvalid optimizer parameter '{0}'. Supported values are 'rmsprop', 'adam', 'adadelta', and 'sgd'.".format(
Code example #7
crnn = crnn.cuda()


def weight_init(module):
    class_name = module.__class__.__name__
    if class_name.find('Conv') != -1:
        module.weight.data.normal_(0, 0.02)
    if class_name.find('BatchNorm') != -1:
        module.weight.data.normal_(1, 0.02)
        module.bias.data.fill_(0)


crnn.apply(weight_init)

loss_function = CTCLoss(zero_infinity=True)
loss_function = loss_function.cuda()
optimizer = Adadelta(crnn.parameters())
converter = Converter(option.alphabet)
print_every = 100
total_loss = 0.0


def validation():
    print('start validation...')
    crnn.eval()
    total_loss = 0.0
    n_correct = 0
    for i, (input, label) in enumerate(validationset_dataloader):
        if i == len(validationset_dataloader) - 1:
            continue
        if i == 9:
Code example #8
File: train.py  Project: qq751220449/crnn_torch
def main():
    config = Config()

    if not os.path.exists(config.expr_dir):
        os.makedirs(config.expr_dir)

    if torch.cuda.is_available() and not config.use_cuda:
        print("WARNING: You have a CUDA device, so you should probably set cuda in params.py to True")

    # Load the training dataset
    train_dataset = HubDataset(config, "train", transform=None)

    train_kwargs = {'num_workers': 2, 'pin_memory': True,
                    'collate_fn': alignCollate(config.img_height, config.img_width, config.keep_ratio)} if torch.cuda.is_available() else {}

    training_data_batch = DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, drop_last=False, **train_kwargs)

    # Load the fixed-length validation dataset
    eval_dataset = HubDataset(config, "eval", transform=transforms.Compose([ResizeNormalize(config.img_height, config.img_width)]))
    eval_kwargs = {'num_workers': 2, 'pin_memory': False} if torch.cuda.is_available() else {}
    eval_data_batch = DataLoader(eval_dataset, batch_size=config.eval_batch_size, shuffle=False, drop_last=False, **eval_kwargs)

    # Load the variable-length validation dataset
    # eval_dataset = HubDataset(config, "eval")
    # eval_kwargs = {'num_workers': 2, 'pin_memory': False,
    #                'collate_fn': alignCollate(config.img_height, config.img_width, config.keep_ratio)} if torch.cuda.is_available() else {}
    # eval_data_batch = DataLoader(eval_dataset, batch_size=config.eval_batch_size, shuffle=False, drop_last=False, **eval_kwargs)

    # Define the network model
    nclass = len(config.label_classes) + 1
    crnn = CRNN(config.img_height, config.nc, nclass, config.hidden_size, n_rnn=config.n_layers)
    # Load a pretrained model if given
    if config.pretrained != '':
        print('loading pretrained model from %s' % config.pretrained)
        crnn.load_state_dict(torch.load(config.pretrained))
    print(crnn)

    # Compute average for `torch.Variable` and `torch.Tensor`.
    loss_avg = utils.averager()

    # Convert between str and label.
    converter = utils.strLabelConverter(config.label_classes)

    criterion = CTCLoss()           # define the loss function

    # Set up placeholder tensors (refilled each batch)
    image = torch.FloatTensor(config.train_batch_size, 3, config.img_height, config.img_height)
    text = torch.LongTensor(config.train_batch_size * 5)
    length = torch.LongTensor(config.train_batch_size)

    if config.use_cuda and torch.cuda.is_available():
        criterion = criterion.cuda()
        image = image.cuda()
        crnn = crnn.to(config.device)

    image = Variable(image)
    text = Variable(text)
    length = Variable(length)

    # Set up the optimizer
    if config.adam:
        optimizer = optim.Adam(crnn.parameters(), lr=config.lr, betas=(config.beta1, 0.999))
    elif config.adadelta:
        optimizer = optim.Adadelta(crnn.parameters())
    else:
        optimizer = optim.RMSprop(crnn.parameters(), lr=config.lr)

    def val(net, criterion, eval_data_batch):
        print('Start val')
        for p in crnn.parameters():
            p.requires_grad = False
        net.eval()

        n_correct = 0
        loss_avg_eval = utils.averager()
        for data in eval_data_batch:
            cpu_images, cpu_texts = data
            batch_size = cpu_images.size(0)
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts)
            utils.loadData(text, t)
            utils.loadData(length, l)
            preds = crnn(image)
            preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
            cost = criterion(preds, text, preds_size, length) / batch_size
            loss_avg_eval.add(cost)         # accumulate validation loss

            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
            cpu_texts_decode = list(cpu_texts)
            for pred, target in zip(sim_preds, cpu_texts_decode):       # count exact matches
                if pred == target:
                    n_correct += 1

            raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:config.n_val_disp]
            for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts_decode):
                print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

        accuracy = n_correct / float(len(eval_dataset))
        print('Val loss: %f, accuracy: %f' % (loss_avg_eval.val(), accuracy))

    # Train on one batch of data
    def train(net, criterion, optimizer, data):
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)             # current batch size
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts)          # encode strings to class indices
        utils.loadData(text, t)
        utils.loadData(length, l)
        optimizer.zero_grad()                       # clear the gradients
        preds = net(image)
        preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        cost.backward()
        optimizer.step()
        return cost

    for epoch in range(config.nepoch):
        i = 0
        for batch_data in training_data_batch:
            for p in crnn.parameters():
                p.requires_grad = True
            crnn.train()
            cost = train(crnn, criterion, optimizer, batch_data)
            loss_avg.add(cost)
            i += 1

            if i % config.displayInterval == 0:
                print('[%d/%d][%d/%d] Loss: %f' %
                      (epoch, config.nepoch, i, len(training_data_batch), loss_avg.val()))
                loss_avg.reset()

            # if i % config.valInterval == 0:
            #     val(crnn, criterion, eval_data_batch)
            #
            # # do checkpointing
            # if i % config.saveInterval == 0:
            #     torch.save(crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(config.expr_dir, epoch, i))

        val(crnn, criterion, eval_data_batch)
        torch.save(crnn.state_dict(), '{0}/netCRNN_{1}_end.pth'.format(config.expr_dir, epoch))
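
Every example here funnels ground-truth strings through a strLabelConverter-style encode before the loss: characters become 1-based class indices so 0 stays reserved for the CTC blank, and the flattened indices travel with a per-sample length tensor. A minimal sketch of that encoding (assumed, for illustration):

import torch

class LabelConverter:
    # Map strings to 1-based class indices; index 0 is the CTC blank.
    def __init__(self, alphabet):
        self.char2id = {c: i + 1 for i, c in enumerate(alphabet)}

    def encode(self, texts):
        lengths = torch.IntTensor([len(t) for t in texts])
        flat = torch.IntTensor([self.char2id[c] for t in texts for c in t])
        return flat, lengths

# converter = LabelConverter('0123456789')
# text, length = converter.encode(['12', '345'])  # text=[2,3,4,5,6], length=[2,3]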
Code example #9
File: train.py  Project: xufuzhi/asdd123
                cls.r_adj.step()
            elif opt.lr_sch == 'N':
                pass
            else:
                raise ValueError

    image = torch.empty((opt.batchSize, 3, opt.imgH, opt.imgH),
                        dtype=torch.float32)
    text = torch.empty(opt.batchSize * 5, dtype=torch.int32)
    length = torch.empty(opt.batchSize, dtype=torch.int32)

    if opt.cuda:
        net_crnn.cuda()
        # net_crnn = torch.nn.DataParallel(net_crnn, device_ids=range(opt.ngpu))
        image = image.cuda()
        ctc_loss = ctc_loss.cuda()

    # loss Averager
    loss_avg = utils.Averager()
    # ### begin training
    iteration, total_iter = 0, len(train_loader) * opt.nepoch
    best_precision = 0
    for epoch in range(opt.nepoch):
        for i, data in enumerate(train_loader, start=1):
            iteration += 1
            for p in net_crnn.parameters():
                p.requires_grad = True
            net_crnn.train()

            # ### train one batch ################################
            cpu_images, cpu_texts = data
Code example #10
File: test_main.py  Project: happog/FudanOCR
def main(config_yaml):
    '''
    Training/Finetune CNN_RNN_Attention Model.
    '''
    #### Load config settings. ####
    f = open(config_yaml)
    opt = yaml.safe_load(f)
    if not os.path.isdir(opt['LOGGER_PATH']):
        os.mkdir(opt['LOGGER_PATH'])
    logger = Logger(opt['LOGGER_PATH'])
    if not os.path.isdir(opt['SAVE_PATH']):
        os.makedirs(opt['SAVE_PATH'])
    manualSeed = random.randint(1, 10000)
    random.seed(manualSeed)
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    cudnn.benchmark = True

    #### Set up DataLoader. ####
    train_cfg = opt['TRAIN']
    ds_cfg = train_cfg['DATA_SOURCE']
    print('Building up dataset:{}'.format(ds_cfg['TYPE']))
    if ds_cfg['TYPE'] == 'SYN_DATA':
        text_gen = util.TextGenerator(ds_cfg['GEN_SET'], ds_cfg['GEN_LEN'])
        ds_train = dataset.synthDataset(ds_cfg['FONT_ROOT'],
                                        ds_cfg['FONT_SIZE'], text_gen)
    elif ds_cfg['TYPE'] == 'IMG_DATA':
        ds_train = dataset.trainDataset(
            ds_cfg['IMG_ROOT'], ds_cfg['TRAIN_SET'],
            transform=None)  #dataset.graybackNormalize()
    assert ds_train
    train_loader = torch.utils.data.DataLoader(
        ds_train,
        batch_size=train_cfg['BATCH_SIZE'],
        shuffle=True,
        sampler=None,
        num_workers=opt['WORKERS'],
        collate_fn=dataset.alignCollate(imgH=train_cfg['IMG_H'],
                                        imgW=train_cfg['MAX_W']))

    val_cfg = opt['VALIDATION']
    ds_val = dataset.testDataset(val_cfg['IMG_ROOT'],
                                 val_cfg['VAL_SET'],
                                 transform=None)  #dataset.graybackNormalize()
    assert ds_val
    val_loader = torch.utils.data.DataLoader(ds_val,
                                             batch_size=16,
                                             shuffle=False,
                                             num_workers=opt['WORKERS'],
                                             collate_fn=dataset.alignCollate(
                                                 imgH=train_cfg['IMG_H'],
                                                 imgW=train_cfg['MAX_W']))

    #### Model construction and Initialization. ####
    alphabet = keys.alphabet
    nClass = len(alphabet) + 1

    if opt['N_GPU'] > 1:
        opt['RNN']['multi_gpu'] = True
    else:
        opt['RNN']['multi_gpu'] = False
    model = crann.CRANN(opt, nClass)
    #print(model)

    #### Train/Val the model. ####
    converter = util.strLabelConverter(alphabet)
    criterion = CTCLoss()
    if opt['CUDA']:
        model.cuda()
        criterion.cuda()

    if opt['OPTIMIZER'] == 'RMSprop':
        optimizer = optim.RMSprop(model.parameters(), lr=opt['TRAIN']['LR'])
    elif opt['OPTIMIZER'] == 'Adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=opt['TRAIN']['LR'],
                               betas=(opt['TRAIN']['BETA1'], 0.999))
    elif opt['OPTIMIZER'] == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=opt['TRAIN']['LR'])
    else:
        optimizer = optim.Adadelta(model.parameters(), lr=opt['TRAIN']['LR'])

    start_epoch = 0
    if opt['VAL_ONLY']:
        print('=>loading pretrained model from %s for val only.' %
              opt['CRANN'])
        checkpoint = torch.load(opt['CRANN'])
        model.load_state_dict(checkpoint['state_dict'])
        val(model, val_loader, criterion, converter, 0, 0, logger, True)
    elif opt['FINETUNE']:
        print('=>loading pretrained model from %s for finetune.' %
              opt['CRANN'])
        checkpoint = torch.load(opt['CRANN'])
        #model.load_state_dict(checkpoint['state_dict'])
        model_dict = model.state_dict()
        #print(model_dict.keys())
        cnn_dict = {
            "cnn." + k: v
            for k, v in checkpoint.items() if "cnn." + k in model_dict
        }
        model_dict.update(cnn_dict)
        model.load_state_dict(model_dict)
        for epoch in range(start_epoch, opt['EPOCHS']):
            adjust_lr(optimizer, opt['TRAIN']['LR'], epoch, opt['STEP'])
            train(model, train_loader, val_loader, criterion, optimizer, opt,
                  converter, epoch, logger)
    elif opt['RESUME']:
        print('=>loading checkpoint from %s for resume training.' %
              opt['CRANN'])
        checkpoint = torch.load(opt['CRANN'])
        start_epoch = checkpoint['epoch'] + 1
        print('resume from epoch:{}'.format(start_epoch))
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        for epoch in range(start_epoch, opt['EPOCHS']):
            adjust_lr(optimizer, opt['TRAIN']['LR'], epoch, opt['STEP'])
            train(model, train_loader, val_loader, criterion, optimizer, opt,
                  converter, epoch, logger)
    else:
        print('train from scratch.')
        for epoch in range(start_epoch, opt['EPOCHS']):
            adjust_lr(optimizer, opt['TRAIN']['LR'], epoch, opt['STEP'])
            train(model, train_loader, val_loader, criterion, optimizer, opt,
                  converter, epoch, logger)
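
The adjust_lr helper shared by examples #3 and #10 is a plain step decay, lr = base_lr * 0.1 ** (epoch // step). PyTorch's built-in StepLR traces the same curve, shown here on a throwaway optimizer for comparison:

import torch
from torch.optim.lr_scheduler import StepLR

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
scheduler = StepLR(optimizer, step_size=3, gamma=0.1)  # lr = 0.1 * 0.1 ** (epoch // 3)

for epoch in range(7):
    optimizer.step()  # training would happen here
    scheduler.step()
    print(epoch, optimizer.param_groups[0]['lr'])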