示例#1
0
def InitStandardRD(wandb=None,
                   nlr=1e-4,
                   rd_low_loss_learn=False,
                   load_model=load_model_name):
    """Build a CRNN-backed ReadingDiscriminator and its LR scheduler.

    Args:
        wandb: optional wandb module/run; when provided, the net is watched.
        nlr: learning rate for the Adam optimizer.
        rd_low_loss_learn: forwarded to ReadingDiscriminator.
        load_model: checkpoint path forwarded as ``load_model_full_path``.

    Returns:
        Tuple of (ReadingDiscriminator, MultiStepLR scheduler).
    """
    logger.info('Loading dataset')

    # Build the recognition network.
    logger.info('Preparing Net')
    net = CRNN(1, len(classes), 256)
    if wandb is not None:
        wandb.watch(net)

    ctc_loss = torch.nn.CTCLoss()
    optimizer = torch.optim.Adam(net.parameters(), nlr, weight_decay=0.00005)
    # Drop the learning rate at 50% and 75% of the epoch budget.
    milestones = [int(.5 * max_epochs), int(.75 * max_epochs)]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones)

    logger.info('Initializing Reading Discriminator')
    rd = ReadingDiscriminator(optimizer,
                              net,
                              ctc_loss,
                              1e-4,
                              load_model_full_path=load_model,
                              rd_low_loss_learn=rd_low_loss_learn)
    return rd, scheduler
示例#2
0
def load_model(lexicon,
               seq_proj=None,
               backend='resnet18',
               base_model_dir=None,
               snapshot=None,
               cuda=True,
               do_beam_search=False,
               dropout_conv=False,
               dropout_rnn=False,
               dropout_output=False,
               do_ema=False,
               ada_after_rnn=False,
               ada_before_rnn=False,
               rnn_hidden_size=128):
    """Construct a CRNN, optionally restore weights, and move it to GPU.

    Args:
        lexicon: character set forwarded to the CRNN.
        seq_proj: sequence projection spec; defaults to [0, 0].
        backend: CNN backbone name.
        base_model_dir: forwarded to the CRNN constructor.
        snapshot: optional path to a saved state dict to load.
        cuda: move the network to GPU when True.
        remaining flags: forwarded unchanged to the CRNN constructor.

    Returns:
        The constructed (and possibly GPU-resident) CRNN.
    """
    # BUG FIX: the default used to be the mutable literal [0, 0], shared
    # across calls; use a None sentinel and build a fresh list per call.
    if seq_proj is None:
        seq_proj = [0, 0]
    net = CRNN(lexicon=lexicon,
               seq_proj=seq_proj,
               backend=backend,
               base_model_dir=base_model_dir,
               do_beam_search=do_beam_search,
               dropout_conv=dropout_conv,
               dropout_rnn=dropout_rnn,
               dropout_output=dropout_output,
               do_ema=do_ema,
               ada_after_rnn=ada_after_rnn,
               ada_before_rnn=ada_before_rnn,
               rnn_hidden_size=rnn_hidden_size)
    if snapshot is not None:
        print('snapshot is: {}'.format(snapshot))
        load_weights(net, torch.load(snapshot))
    if cuda:
        print('setting network on gpu')
        net = net.cuda()
        print('set network on gpu')
    return net
示例#3
0
def main(data_path, abc, seq_proj, backend, snapshot, input_size, gpu,
         visualize):
    """Load a CRNN from a snapshot and run detection on a test dataset.

    Args:
        data_path: dataset root; when None no dataset is loaded.
        abc: ignored — overwritten below with a hard-coded alphabet.
        seq_proj: 'AxB' string parsed into a two-int projection spec.
        backend: CNN backbone name.
        snapshot: optional path to saved weights.
        input_size: 'WxH' string parsed into the resize target.
        gpu: CUDA_VISIBLE_DEVICES value; empty string means CPU.
        visualize: forwarded to detect().
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # BUG FIX: `gpu is not ''` compared identity, not equality — its result
    # depends on string interning. Use a plain inequality test.
    cuda = gpu != ''
    abc = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
    input_size = [int(x) for x in input_size.split('x')]
    transform = Compose(
        [Rotation(), Resize(size=(input_size[0], input_size[1]))])
    # NOTE(review): `data` is only bound when data_path is not None; the
    # detect() call below raises NameError otherwise — confirm callers
    # always pass a data_path.
    if data_path is not None:
        data = LoadDataset(data_path=data_path,
                           mode="test",
                           transform=transform)

    seq_proj = [int(x) for x in seq_proj.split('x')]

    net = CRNN(abc=abc, seq_proj=seq_proj, backend=backend)
    if snapshot is not None:
        load_weights(net, torch.load(snapshot))
    if cuda:
        net = net.cuda()
    net = net.eval()
    detect(net, data, cuda, visualize)
示例#4
0
    def __init__(self, args):
        """Build a CRNN for inference and load pretrained weights.

        Args:
            args: namespace with gpus, imgH, nh, model_path, cuda.
        """
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
        self.args = args
        self.alphabet = alphabetChinese
        nclass = len(self.alphabet) + 1  # +1 for the CTC blank symbol
        nc = 1  # single-channel (grayscale) input
        self.net = CRNN(args.imgH, nc, args.nh, nclass)
        self.converter = utils.strLabelConverter(self.alphabet, ignore_case=False)
        self.transformer = resizeNormalize(args.imgH)

        print('loading pretrained model from %s' % args.model_path)
        checkpoint = torch.load(args.model_path)
        if 'model_state_dict' in checkpoint.keys():
            checkpoint = checkpoint['model_state_dict']
        from collections import OrderedDict
        # Strip the 'module.' prefix that DataParallel adds to state-dict keys.
        model_dict = OrderedDict()
        for k, v in checkpoint.items():
            if 'module' in k:
                model_dict[k[7:]] = v
            else:
                model_dict[k] = v
        self.net.load_state_dict(model_dict)

        if args.cuda and torch.cuda.is_available():
            print('available gpus is,', torch.cuda.device_count())
            # BUG FIX: DataParallel has no `output_dim` keyword (the gather
            # axis is called `dim`); passing output_dim=1 raised TypeError.
            self.net = torch.nn.DataParallel(self.net).cuda()

        self.net.eval()
示例#5
0
def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    """Snapshot the model weights and optimizer state to *filepath*."""
    print("Saving model and optimizer state at iteration {} to {}".format(
        iteration, filepath))
    # Copy the weights into a fresh CRNN so the serialized object is
    # detached from the live training instance.
    snapshot_model = CRNN(**CRNN_config).cuda()
    snapshot_model.load_state_dict(model.state_dict())
    state = {
        'model': snapshot_model,
        'iteration': iteration,
        'optimizer': optimizer.state_dict(),
        'learning_rate': learning_rate,
    }
    torch.save(state, filepath)
def load_model(abc,
               seq_proj=None,
               backend='resnet18',
               snapshot=None,
               cuda=True):
    """Construct a DataParallel-wrapped CRNN and optionally load weights.

    Args:
        abc: alphabet forwarded to the CRNN.
        seq_proj: sequence projection spec; defaults to [0, 0].
        backend: CNN backbone name.
        snapshot: optional path to a saved state dict.
        cuda: move the network to GPU when True.

    Returns:
        The constructed network.
    """
    # BUG FIX: the default used to be the mutable literal [0, 0], shared
    # across calls; use a None sentinel instead.
    if seq_proj is None:
        seq_proj = [0, 0]
    net = CRNN(abc=abc, seq_proj=seq_proj, backend=backend)
    net = nn.DataParallel(net)
    if snapshot is not None:
        load_weights(net, torch.load(snapshot))
    if cuda:
        net = net.cuda()
    return net
示例#7
0
    def __init__(self):
        """Set up the CRNN trainer: model, data, loss, optimizer, resume.

        Reads configuration from the module-level `args` namespace.
        """
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
        if args.chars_file == '':
            self.alphabet = alphabetChinese
        else:
            self.alphabet = utils.load_chars(args.chars_file)
        nclass = len(self.alphabet) + 1  # +1 for the CTC blank symbol
        nc = 1  # single-channel (grayscale) input
        self.net = CRNN(args.imgH, nc, args.nh, nclass)
        self.train_dataloader, self.val_dataloader = self.dataloader(
            self.alphabet)
        self.criterion = CTCLoss()
        self.optimizer = self.get_optimizer()
        self.converter = utils.strLabelConverter(self.alphabet,
                                                 ignore_case=False)
        self.best_acc = 0.00001

        model_name = '%s' % (args.dataset_name)
        if not os.path.exists(args.save_prefix):
            os.mkdir(args.save_prefix)
        args.save_prefix += model_name

        if args.pretrained != '':
            print('loading pretrained model from %s' % args.pretrained)
            checkpoint = torch.load(args.pretrained)

            if 'model_state_dict' in checkpoint.keys():
                # Full training checkpoint: restore bookkeeping, then keep
                # only the raw state dict for the weight load below.
                args.start_epoch = checkpoint['epoch']
                self.best_acc = checkpoint['best_acc']
                checkpoint = checkpoint['model_state_dict']

            from collections import OrderedDict
            # Strip the 'module.' prefix DataParallel adds to state-dict keys.
            model_dict = OrderedDict()
            for k, v in checkpoint.items():
                if 'module' in k:
                    model_dict[k[7:]] = v
                else:
                    model_dict[k] = v
            self.net.load_state_dict(model_dict)

        if not args.cuda and torch.cuda.is_available():
            print(
                "WARNING: You have a CUDA device, so you should probably run with --cuda"
            )

        elif args.cuda and torch.cuda.is_available():
            print('available gpus is ', torch.cuda.device_count())
            # BUG FIX: DataParallel has no `output_dim` keyword (the gather
            # axis is called `dim`); passing output_dim=1 raised TypeError.
            self.net = torch.nn.DataParallel(self.net).cuda()
            self.criterion = self.criterion.cuda()
示例#8
0
def train(output_directory, epochs, learning_rate, iters_per_checkpoint,
          batch_size, seed, checkpoint_path):
    """Train a CRNN next-frame predictor with L1 loss on LJSpeech data.

    Args:
        output_directory: where per-epoch checkpoints are written.
        epochs: total number of epochs to run.
        learning_rate: Adam learning rate.
        iters_per_checkpoint: logging interval (iterations).
        batch_size: DataLoader batch size.
        seed: RNG seed for CPU and CUDA.
        checkpoint_path: existing checkpoint to resume from ('' = none).
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    criterion = torch.nn.L1Loss()
    model = CRNN(**CRNN_config).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Resume from a checkpoint if one was given.
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1  # next iteration is iteration + 1

    trainset = LJspeechDataset(**data_config)
    train_loader = DataLoader(trainset, num_workers=1, shuffle=True,
                              batch_size=batch_size,
                              collate_fn=collate_fn,
                              pin_memory=False,
                              drop_last=True)

    if not os.path.isdir(output_directory):
        os.makedirs(output_directory)
        os.chmod(output_directory, 0o775)
    print("output directory", output_directory)

    model.train()
    epoch_offset = max(0, int(iteration / len(train_loader)))

    for epoch in range(epoch_offset, epochs):
        epoch_ave_loss = 0
        num_batches = 0
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            model.zero_grad()

            # Teacher forcing: feed frames [0, T-1), predict frames [1, T).
            # (torch.autograd.Variable is a deprecated no-op; tensors are
            # used directly.)
            netFeed = batch[:, :-1, :].cuda()
            netTarget = batch[:, 1:, :].cuda()

            netOutput = model(netFeed)

            loss = criterion(netOutput, netTarget)
            reduced_loss = loss.item()

            loss.backward()
            optimizer.step()

            if (iteration % iters_per_checkpoint == 0):
                print("{}:\t{:.9f}".format(iteration, reduced_loss))
            iteration += 1
            epoch_ave_loss += reduced_loss
            num_batches += 1

        epoch_checkpoint = "{}/CRNN_net_{}".format(output_directory, epoch)
        save_checkpoint(model, optimizer, learning_rate, iteration,
                        epoch_checkpoint)
        # BUG FIX: the average used to divide by the last batch index `i`
        # (n - 1 instead of n), and raised NameError/ZeroDivisionError when
        # the loader produced no batches.
        if num_batches:
            epoch_ave_loss = epoch_ave_loss / num_batches
        print("Epoch: {}, the average epoch loss: {}".format(
            epoch, epoch_ave_loss))
示例#9
0
def main():
    """Entry point: parse CLI args, build loaders, model, loss, optimizer
    and LR scheduler, optionally restore weights, then run the epoch loop.

    NOTE(review): relies on module-level names defined elsewhere in the
    project (parse_args, utils, Dataset, DatasetCollater, CRNN, DenseNet,
    CTCLoss, summary, train).
    """
    # Set parameters of the trainer
    global args, device
    args = parse_args()

    print('=' * 60)
    print(args)
    print('=' * 60)
    # Seed every RNG source for reproducibility.
    random.seed(args.manual_seed)
    np.random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)

    if torch.cuda.is_available() and not args.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )
    if args.cuda and torch.cuda.is_available():
        device = torch.device('cuda')
        #cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # load alphabet from file: keep the first character of each UTF-8
    # decoded line (one symbol per line).
    if os.path.isfile(args.alphabet):
        alphabet = ''
        with open(args.alphabet, mode='rb') as f:
            for line in f.readlines():
                alphabet += line.decode('utf-8')[0]
        args.alphabet = alphabet

    converter = utils.CTCLabelConverter(args.alphabet, ignore_case=False)

    # data loader: the collater resizes/pads batch images to image_size.
    image_size = (args.image_h, args.image_w)
    collater = DatasetCollater(image_size, keep_ratio=args.keep_ratio)
    train_dataset = Dataset(mode='train',
                            data_root=args.data_root,
                            transform=None)
    #sampler = RandomSequentialSampler(train_dataset, args.batch_size)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   collate_fn=collater,
                                   shuffle=True,
                                   num_workers=args.workers)

    val_dataset = Dataset(mode='val', data_root=args.data_root, transform=None)
    val_loader = data.DataLoader(val_dataset,
                                 batch_size=args.batch_size,
                                 collate_fn=collater,
                                 shuffle=True,
                                 num_workers=args.workers)

    # network: +1 class for the CTC blank symbol, grayscale input.
    num_classes = len(args.alphabet) + 1
    num_channels = 1
    if args.arch == 'crnn':
        model = CRNN(args.image_h, num_channels, num_classes, args.num_hidden)
    elif args.arch == 'densenet':
        model = DenseNet(
            num_channels=num_channels,
            num_classes=num_classes,
            growth_rate=12,
            block_config=(3, 6, 9),  #(3,6,12,16),
            compression=0.5,
            num_init_features=64,
            bn_size=4,
            rnn=args.rnn,
            num_hidden=args.num_hidden,
            drop_rate=0,
            small_inputs=True,
            efficient=False)
    else:
        raise ValueError('unknown architecture {}'.format(args.arch))
    model = model.to(device)
    # Dry-run summary on a dummy batch of two 32x650 grayscale images.
    summary(model, torch.zeros((2, 1, 32, 650)).to(device))
    #print('='*60)
    #print(model)
    #print('='*60)

    # loss
    criterion = CTCLoss()
    criterion = criterion.to(device)

    # setup optimizer (selected by name from the CLI)
    if args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(),
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               betas=(0.9, 0.999),
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'adadelta':
        optimizer = optim.Adadelta(model.parameters(),
                                   weight_decay=args.weight_decay)
    else:
        raise ValueError('unknown optimizer {}'.format(args.optimizer))
    print('=' * 60)
    print(optimizer)
    print('=' * 60)

    # Define learning rate decay schedule
    global scheduler
    #exp_decay = math.exp(-0.1)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                 gamma=args.decay_rate)
    #step_size = 10000
    #gamma_decay = 0.8
    #scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma_decay)
    #scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=gamma_decay)

    # initialize model from pretrained weights (plain state dict)
    if args.pretrained and os.path.isfile(args.pretrained):
        print(">> Using pre-trained model '{}'".format(
            os.path.basename(args.pretrained)))
        state_dict = torch.load(args.pretrained)
        model.load_state_dict(state_dict)
        print("loading pretrained model done.")

    global is_best, best_accuracy
    is_best = False
    best_accuracy = 0.0
    start_epoch = 0
    # optionally resume from a checkpoint (restores epoch, accuracy and
    # weights; optimizer/scheduler restore is intentionally commented out)
    if args.resume:
        if os.path.isfile(args.resume):
            # load checkpoint weights and update model and optimizer
            print(">> Loading checkpoint:\n>> '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_accuracy = checkpoint['best_accuracy']
            print(">>>> loaded checkpoint:\n>>>> '{}' (epoch {})".format(
                args.resume, start_epoch))
            model.load_state_dict(checkpoint['state_dict'])
            #optimizer.load_state_dict(checkpoint['optimizer'])
            # important not to forget scheduler updating
            #scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.decay_rate, last_epoch=start_epoch - 1)
        else:
            print(">> No checkpoint found at '{}'".format(args.resume))

    # Create export dir if it doesnt exist; its name encodes the
    # hyperparameters of this run.
    checkpoint = "{}".format(args.arch)
    checkpoint += "_{}".format(args.optimizer)
    checkpoint += "_lr_{}".format(args.lr)
    checkpoint += "_decay_rate_{}".format(args.decay_rate)
    checkpoint += "_bsize_{}".format(args.batch_size)
    checkpoint += "_height_{}".format(args.image_h)
    checkpoint += "_keep_ratio" if args.keep_ratio else "_width_{}".format(
        image_size[1])

    args.checkpoint = os.path.join(args.checkpoint, checkpoint)
    if not os.path.exists(args.checkpoint):
        os.makedirs(args.checkpoint)

    print('start training...')
    for epoch in range(start_epoch, args.max_epoch):
        # Adjust learning rate for each epoch.
        # NOTE(review): scheduler.step() runs before the epoch's optimizer
        # steps; modern PyTorch warns about this ordering — confirm intended.
        scheduler.step()

        # Train for one epoch on train set
        _ = train(train_loader, val_loader, model, criterion, optimizer, epoch,
                  converter)
def main():
    """Train a CRNN with a CTC head plus an auxiliary attention head.

    Reads config.yaml, builds train/val/test loaders, then runs the epoch
    loop: validates every epoch, tests every `test_freq` epochs, and saves
    weights every `save_freq` epochs.
    """
    print(torch.__version__)

    with open('config.yaml') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    print(torch.cuda.is_available())
    torch.backends.cudnn.benchmark = True

    char_set = config['char_set']
    # Build both the CTC and the attention character dictionaries.
    char2idx_ctc, idx2char_ctc = get_char_dict_ctc(char_set)
    char2idx_att, idx2char_att = get_char_dict_attention(char_set)
    config['char2idx_ctc'] = char2idx_ctc
    config['idx2char_ctc'] = idx2char_ctc
    config['char2idx_att'] = char2idx_att
    config['idx2char_att'] = idx2char_att

    batch_size = config['batch_size']

    if not os.path.exists(config['save_path']):
        os.mkdir(config['save_path'])
    print(config)

    train_dataset = TextRecDataset(config, phase='train')
    val_dataset = TextRecDataset(config, phase='val')
    test_dataset = TextRecDataset(config, phase='test')
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=cpu_count(),
                                  pin_memory=False)

    valloader = data.DataLoader(val_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=cpu_count(),
                                pin_memory=False)

    testloader = data.DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=cpu_count(),
                                 pin_memory=False)

    class_num = len(config['char_set']) + 1  # +1 for the CTC blank
    print('class_num', class_num)
    model = CRNN(class_num)
    attention_head = AttentionHead(class_num, config['max_string_len'], char2idx_att)

    criterion_ctc = CTCFocalLoss(blank=char2idx_ctc['-'], gamma=0.5)
    criterion_att = nn.CrossEntropyLoss(reduction='none')

    if config['use_gpu']:
        model = model.cuda()
        attention_head = attention_head.cuda()
    summary(model, (1, 32, 400))

    # Single SGD optimizer over both the backbone and the attention head.
    optimizer = torch.optim.SGD([{'params': model.parameters()},
                                 {'params': attention_head.parameters()}], lr=0.001, momentum=0.9, weight_decay=5e-4)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[500, 800], gamma=0.1)

    print('train start, total batches %d' % len(trainloader))
    iter_cnt = 0
    # BUG FIX: line_acc/rec_score were referenced in the save filename
    # before any test epoch had run; initialize them so saving never
    # raises NameError.
    line_acc, rec_score = 0.0, 0.0
    for i in range(1, config['epochs']+1):
        start = time.time()
        model.train()
        attention_head.train()
        for j, batch in enumerate(trainloader):

            iter_cnt += 1
            imgs = batch[0].cuda()
            labels_length = batch[1].cuda()
            labels_str = batch[2]
            labels_ctc = batch[3].cuda().long()
            labels_ctc_mask = batch[4].cuda().float()
            labels_att = batch[5].cuda().long()
            labels_att_mask = batch[6].cuda().float()

            if config['method'] == 'ctc':
                # CTC loss
                outputs, cnn_features = model(imgs)
                log_prob = outputs.log_softmax(dim=2)
                t, n, c = log_prob.size(0), log_prob.size(1), log_prob.size(2)
                input_length = (torch.ones((n,)) * t).cuda().int()
                loss_ctc = criterion_ctc(log_prob, labels_ctc, input_length, labels_length)

                # attention loss (masked cross-entropy over valid positions)
                outputs = attention_head(cnn_features, labels_att)
                probs = outputs.permute(1, 2, 0)
                losses_att = criterion_att(probs, labels_att)
                losses_att = losses_att * labels_att_mask
                losses_att = losses_att.sum() / labels_att_mask.sum()

                loss = loss_ctc + losses_att

            else:
                # attention-only cross-entropy loss.
                # BUG FIX: this branch referenced undefined names
                # (`decoder`, `label_att`, `criterion`); use the
                # attention_head / labels_att / criterion_att built above.
                outputs_ctc, sqs = model(imgs)
                outputs_att = attention_head(sqs, labels_att)

                outputs = outputs_att.permute(1, 2, 0)
                losses = criterion_att(outputs, labels_att)
                losses = losses * labels_att_mask
                loss = losses.sum() / labels_att_mask.sum()

            optimizer.zero_grad()
            loss.backward()
            # nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
            optimizer.step()

            if iter_cnt % config['print_freq'] == 0:
                print('epoch %d, iter %d, train loss %f' % (i, iter_cnt, loss.item()))

        print('epoch %d, time %f' % (i, (time.time() - start)))
        scheduler.step()

        print("validating...")

        if config['method'] == 'ctc':
            eval_ctc(model, valloader, idx2char_ctc)
        else:
            # BUG FIX: `decoder` was undefined; evaluate the attention head.
            eval_attention(model, attention_head, valloader, idx2char_att)

        if i % config['test_freq'] == 0:
            print("testing...")
            if config['method'] == 'ctc':
                line_acc, rec_score = eval_ctc(model, testloader, idx2char_ctc)
            else:
                line_acc, rec_score = eval_attention(model, attention_head, testloader, idx2char_att)

        if i % config['save_freq'] == 0:
            save_file_name = f"epoch_{i}_acc_{line_acc:.3f}_rec_score_{rec_score:.3f}.pth"
            save_file = os.path.join(config['save_path'], save_file_name)
            torch.save(model.state_dict(), save_file)
示例#11
0
def main(arg):
    """Train a CRNN with CTC loss on LMDB datasets, checkpointing
    periodically, then (tail section) prepare pruning of conv/batchnorm
    layers.

    NOTE(review): the code after the final torch.save references `opt`,
    which is never defined in this function — it looks like a fragment
    from a different script and will raise NameError if reached; confirm
    before running.
    """
    print(arg)
    train_dataset = dataset.lmdbDataset(
        path=arg.train_root,
        # transform=dataset.resizeNormalize((imgW,imgH)),
    )
    test_dataset = dataset.lmdbDataset(
        path=arg.test_root,
        # transform=dataset.resizeNormalize((arg.imgW,arg.imgH)),
    )
    # Smoke-test the test dataset before starting.
    d = test_dataset.__getitem__(0)
    l = test_dataset.__len__()
    train_loader = DataLoader(train_dataset,
                              num_workers=arg.num_workers,
                              batch_size=arg.batch_size,
                              collate_fn=dataset.alignCollate(
                                  imgH=arg.imgH,
                                  imgW=arg.imgW,
                                  keep_ratio=arg.keep_ratio),
                              shuffle=True,
                              drop_last=True)

    criterion = CTCLoss()
    converter = utils.Converter(arg.num_class)
    crnn = CRNN(imgH=arg.imgH, nc=3, nclass=arg.num_class + 1, nh=256)

    # custom weights initialization called on crnn
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    crnn.apply(weights_init)
    print(crnn)

    # Reusable buffers, refilled in-place each iteration via utils.loadData.
    image = torch.FloatTensor(arg.batch_size, 3, arg.imgH, arg.imgW)
    text = torch.IntTensor(arg.batch_size * 5)
    length = torch.IntTensor(arg.batch_size)

    image = Variable(image)
    text = Variable(text)
    length = Variable(length)

    # loss averager
    loss_avg = utils.averager()

    # setup optimizer
    if arg.opt == 'adam':
        optimizer = optim.Adam(crnn.parameters(), 0.01, betas=(0.5, 0.999))
    elif arg.opt == 'adadelta':
        optimizer = optim.Adadelta(crnn.parameters())
    else:
        optimizer = optim.RMSprop(crnn.parameters(), 0.01)

    for epoch in range(arg.n_epoch):
        train_iter = iter(train_loader)
        i = 0
        while i < len(train_loader):
            for p in crnn.parameters():
                p.requires_grad = True
            crnn.train()

            # BUG FIX: `train_iter.next()` is Python 2 only and raises
            # AttributeError on Python 3; use the built-in next().
            data = next(train_iter)
            cpu_images, cpu_texts = data
            batch_size = cpu_images.size(0)
            utils.loadData(image, cpu_images)
            text_labels, l = converter.encode(cpu_texts)
            utils.loadData(text, text_labels)
            utils.loadData(length, l)

            preds = crnn(image)
            preds_size = Variable(torch.IntTensor([preds.size(0)] *
                                                  batch_size))
            # Normalize the CTC cost by the batch size.
            cost = criterion(preds, text, preds_size, length) / batch_size
            crnn.zero_grad()
            cost.backward()
            optimizer.step()

            loss_avg.add(cost)
            i += 1

            if i % arg.displayInterval == 0:
                print(
                    '[%d/%d][%d/%d] Loss: %f' %
                    (epoch, arg.n_epoch, i, len(train_loader), loss_avg.val()))
                loss_avg.reset()

            if i % arg.testInterval == 0:
                test(arg, crnn, test_dataset, criterion, image, text, length)

            # do checkpointing
            if i % arg.saveInterval == 0:
                name = '{0}/netCRNN_{1}_{2}_{3}_{4}.pth'.format(
                    arg.model_dir, arg.num_class, arg.type, epoch, i)
                torch.save(crnn.state_dict(), name)
                print('model saved at ', name)
    torch.save(
        crnn.state_dict(),
        '{0}/netCRNN_{1}_{2}.pth'.format(arg.model_dir, arg.num_class,
                                         arg.type))
    # Get converter, transformer,
    # NOTE(review): everything below uses the undefined `opt` namespace —
    # see the docstring note above.
    if not os.path.exists(opt.expr_dir):
        os.makedirs(opt.expr_dir)

    random.seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    transformer = dataset.resizeNormalize((100, 32))

    nclass = len(opt.alphabet) + 1
    nc = 1

    converter = misc.strLabelConverter(opt.alphabet)

    crnn = CRNN(opt.imgH, nc, nclass, opt.nh)
    if opt.pretrained != '':
        print('loading pretrained model from %s' % opt.pretrained)
        crnn.load_state_dict(load_multi(opt.pretrained), strict=False)

    # Process pruned conv2d and batchnorm2d, store them in a dictionary
    crnn_l = list(crnn.cnn._modules.items())

    last_channels = [0]

    crnn_new = CRNN(opt.imgH, nc, nclass, opt.nh)
    crnn_new = copy.deepcopy(crnn)
    new_dict = {}

    for i in range(len(crnn_l)):
        module = crnn_l[i][1]
示例#13
0
from warpctc_pytorch import CTCLoss
from models.resnet_crnn import ResNetCRNN
from models.crnn import CRNN
import torch
from torch.utils.checkpoint import checkpoint
from time import time

r = CRNN(128, 3, 9100, 512).cuda()
input = torch.ones(size=(4, 3, 128, 300000)).cuda()
c = CTCLoss().cuda()
# input.requires_grad =True

cp = True
try:
    ok = False
    try:
        start = time()
        output = r(input, cp=cp)
        l = torch.ones(size=(8, 30)).int()
        lsize = torch.IntTensor([30] * 8)
        psize = torch.IntTensor([output.size(0)] * output.size(1))
        if cp:
            loss = checkpoint(c, output, l.view(-1), psize, lsize)
        else:
            loss = c(output, l.view(-1), psize, lsize)
        loss.backward()
        print('cost', time() - start)
        ok = True
    except:
        raise RuntimeError("RE")
    finally:
示例#14
0
                              imgW=arg.imgW,
                              mean=mean,
                              std=std)
    train_loader = DataLoader(train_dataset,
                              batch_size=arg.batch_size,
                              num_workers=arg.num_workers,
                              shuffle=True)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=arg.batch_size,
                              num_workers=arg.num_workers,
                              shuffle=True)
    nc = 1
    num_class = params['num_class'] + 1
    converter = utils.ConverterV2(params['alphabets'])
    criterion = torch.nn.CTCLoss(reduction='sum')
    crnn = CRNN(32, nc, num_class, 256)
    crnn.apply(weights_init)

    # setup optimizer
    if arg.opt == 'adam':
        optimizer = optim.Adam(crnn.parameters(),
                               lr=arg.lr,
                               betas=(0.99, 0.999))
    elif arg.opt == 'adadelta':
        optimizer = optim.Adadelta(crnn.parameters())
    else:
        optimizer = optim.RMSprop(crnn.parameters(), lr=arg.lr)

    crnn.register_backward_hook(backward_hook)

    main(crnn, train_loader, valid_loader, criterion, optimizer)
示例#15
0
class Trainer(object):
    """Train and validate a CRNN text recognizer with CTC loss.

    All hyper-parameters, paths and switches come from the module-level
    ``args`` namespace; the character set comes from ``alphabetChinese``
    or from ``args.chars_file``.
    """

    def __init__(self):
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
        # Character set: built-in Chinese alphabet unless a chars file is given.
        if args.chars_file == '':
            self.alphabet = alphabetChinese
        else:
            self.alphabet = utils.load_chars(args.chars_file)
        nclass = len(self.alphabet) + 1  # +1 for the CTC blank label
        nc = 1  # grayscale input
        self.net = CRNN(args.imgH, nc, args.nh, nclass)
        self.train_dataloader, self.val_dataloader = self.dataloader(
            self.alphabet)
        self.criterion = CTCLoss()
        self.optimizer = self.get_optimizer()
        self.converter = utils.strLabelConverter(self.alphabet,
                                                 ignore_case=False)
        # Small epsilon so the first real validation score always wins.
        self.best_acc = 0.00001

        model_name = '%s' % (args.dataset_name)
        if not os.path.exists(args.save_prefix):
            os.mkdir(args.save_prefix)
        args.save_prefix += model_name

        if args.pretrained != '':
            print('loading pretrained model from %s' % args.pretrained)
            checkpoint = torch.load(args.pretrained)

            # Full training checkpoints carry epoch/best_acc alongside the
            # weights; plain weight files are used as-is.
            if 'model_state_dict' in checkpoint.keys():
                # self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                args.start_epoch = checkpoint['epoch']
                self.best_acc = checkpoint['best_acc']
                checkpoint = checkpoint['model_state_dict']

            # Strip a leading 'module.' (DataParallel) prefix if present so the
            # weights load into a bare (non-parallel) module.
            from collections import OrderedDict
            model_dict = OrderedDict()
            for k, v in checkpoint.items():
                if 'module' in k:
                    model_dict[k[7:]] = v
                else:
                    model_dict[k] = v
            self.net.load_state_dict(model_dict)

        if not args.cuda and torch.cuda.is_available():
            print(
                "WARNING: You have a CUDA device, so you should probably run with --cuda"
            )

        elif args.cuda and torch.cuda.is_available():
            print('available gpus is ', torch.cuda.device_count())
            # BUG FIX: DataParallel has no 'output_dim' keyword (it raised
            # TypeError); the gather dimension is 'dim'.  CRNN output is
            # (T, B, C), so the batch dimension to gather on is 1.
            self.net = torch.nn.DataParallel(self.net, dim=1).cuda()
            self.criterion = self.criterion.cuda()

    def dataloader(self, alphabet):
        """Build (train, val) DataLoaders; val is None when args.val_dir is absent."""
        # train_transform = transforms.Compose(
        #     [transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
        #     resizeNormalize(args.imgH)])
        # train_dataset = BaseDataset(args.train_dir, alphabet, transform=train_transform)
        train_dataset = NumDataset(args.train_dir,
                                   alphabet,
                                   transform=resizeNormalize(args.imgH))
        train_dataloader = DataLoader(dataset=train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

        if os.path.exists(args.val_dir):
            # val_dataset = BaseDataset(args.val_dir, alphabet, transform=resizeNormalize(args.imgH))
            val_dataset = NumDataset(args.val_dir,
                                     alphabet,
                                     mode='test',
                                     transform=resizeNormalize(args.imgH))
            val_dataloader = DataLoader(dataset=val_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers,
                                        pin_memory=True)
        else:
            val_dataloader = None

        return train_dataloader, val_dataloader

    def get_optimizer(self):
        """Return the optimizer selected by args.optimizer (sgd/adam/rmsprop)."""
        if args.optimizer == 'sgd':
            optimizer = optim.SGD(
                self.net.parameters(),
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.wd,
            )
        elif args.optimizer == 'adam':
            optimizer = optim.Adam(
                self.net.parameters(),
                lr=args.lr,
                betas=(args.beta1, 0.999),
            )
        else:
            optimizer = optim.RMSprop(
                self.net.parameters(),
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.wd,
            )
        return optimizer

    def train(self):
        """Run the training loop, logging progress and checkpointing on validation."""
        logging.basicConfig()
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        log_file_path = args.save_prefix + '_train.log'
        log_dir = os.path.dirname(log_file_path)
        if log_dir and not os.path.exists(log_dir):
            os.mkdir(log_dir)
        fh = logging.FileHandler(log_file_path)
        logger.addHandler(fh)
        logger.info(args)
        logger.info('Start training from [Epoch {}]'.format(args.start_epoch +
                                                            1))

        losses = utils.Averager()
        train_accuracy = utils.Averager()

        for epoch in range(args.start_epoch, args.nepoch):
            self.net.train()
            btic = time.time()
            for i, (imgs, labels) in enumerate(self.train_dataloader):
                batch_size = imgs.size()[0]
                imgs = imgs.cuda()
                preds = self.net(imgs).cpu()
                # length: per-sample label lengths for this batch;
                # text: all label indices of the batch, flattened (CTC targets).
                text, length = self.converter.encode(labels)
                preds_size = torch.IntTensor([preds.size(0)] * batch_size)
                loss_avg = self.criterion(preds, text, preds_size,
                                          length) / batch_size

                self.optimizer.zero_grad()
                loss_avg.backward()
                self.optimizer.step()

                losses.update(loss_avg.item(), batch_size)

                # Greedy (best-path) decode for a rough training accuracy.
                _, preds_m = preds.max(2)
                preds_m = preds_m.transpose(1, 0).contiguous().view(-1)
                sim_preds = self.converter.decode(preds_m.data,
                                                  preds_size.data,
                                                  raw=False)
                n_correct = 0
                for pred, target in zip(sim_preds, labels):
                    if pred == target:
                        n_correct += 1
                train_accuracy.update(n_correct, batch_size, MUL_n=False)

                if args.log_interval and not (i + 1) % args.log_interval:
                    # BUG FIX: throughput was computed against the time since
                    # the *epoch* started, so the reported samples/sec shrank
                    # over the epoch.  Measure per logging interval instead.
                    speed = args.log_interval * batch_size / (time.time() -
                                                              btic)
                    logger.info(
                        '[Epoch {}/{}][Batch {}/{}], Speed: {:.3f} samples/sec, Loss:{:.3f}'
                        .format(epoch + 1, args.nepoch, i + 1,
                                len(self.train_dataloader), speed,
                                losses.val()))
                    losses.reset()
                    btic = time.time()

            logger.info(
                'Training accuracy: {:.3f}, [#correct:{} / #total:{}]'.format(
                    train_accuracy.val(), train_accuracy.sum,
                    train_accuracy.count))
            train_accuracy.reset()

            if args.val_interval and not (epoch + 1) % args.val_interval:
                acc = self.validate(logger)
                if acc > self.best_acc:
                    self.best_acc = acc
                    save_path = '{:s}_best.pth'.format(args.save_prefix)
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': self.net.state_dict(),
                            # 'optimizer_state_dict': self.optimizer.state_dict(),
                            'best_acc': self.best_acc,
                        },
                        save_path)
                # BUG FIX: this went to the root logger via logging.info and
                # bypassed the file handler configured above.
                logger.info("best acc is:{:.3f}".format(self.best_acc))
                # NOTE(review): the periodic save below only runs on
                # validation epochs because of this nesting — confirm whether
                # save_interval was meant to be independent of val_interval.
                if args.save_interval and not (epoch + 1) % args.save_interval:
                    save_path = '{:s}_{:04d}_{:.3f}.pth'.format(
                        args.save_prefix, epoch + 1, acc)
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': self.net.state_dict(),
                            # 'optimizer_state_dict': self.optimizer.state_dict(),
                            'best_acc': self.best_acc,
                        },
                        save_path)

    def validate(self, logger):
        """Evaluate on the validation set; return exact-match accuracy (0 if no val set)."""
        if self.val_dataloader is None:
            return 0
        logger.info('Start validate.')
        losses = utils.Averager()
        self.net.eval()
        n_correct = 0
        with torch.no_grad():
            for i, (imgs, labels) in enumerate(self.val_dataloader):
                batch_size = imgs.size()[0]
                imgs = imgs.cuda()
                preds = self.net(imgs).cpu()
                # length: per-sample label lengths; text: flattened CTC targets.
                text, length = self.converter.encode(labels)
                preds_size = torch.IntTensor(
                    [preds.size(0)] * batch_size)  # timestep * batchsize
                loss_avg = self.criterion(preds, text, preds_size,
                                          length) / batch_size

                losses.update(loss_avg.item(), batch_size)

                _, preds = preds.max(2)
                preds = preds.transpose(1, 0).contiguous().view(-1)
                sim_preds = self.converter.decode(preds.data,
                                                  preds_size.data,
                                                  raw=False)
                for pred, target in zip(sim_preds, labels):
                    if pred == target:
                        n_correct += 1

        # losses.count accumulated batch sizes, i.e. the number of samples seen.
        accuracy = n_correct / float(losses.count)

        logger.info(
            'Evaling loss: {:.3f}, accuracy: {:.3f}, [#correct:{} / #total:{}]'
            .format(losses.val(), accuracy, n_correct, losses.count))

        return accuracy
示例#16
0
    num_classes = 63  # 62 characters + 1 blank for ctc
    batch_size = 64
    seed = 23
    annotations_path = "data/annotations.csv"
    images_path = "data/word_images/data/"
    checkpoint_path = "checkpoints/crnn/"
    data = DataLoader(batch_size,
                      annotations_path=annotations_path,
                      images_path=images_path,
                      seed=seed)

    test, test_steps = data.load_text_data(type='test')

    weights_path = 'checkpoints/crnn_best_weights/crnn_best.h5'

    crnn = CRNN(num_classes, batch_size)
    crnn.load_weights(weights_path)

    results = crnn.evaluate_generator(test, steps=test_steps, verbose=1)

# =============================================================================
#     # TESTING
#     def get_seq_len(data):
#         def py_get_seq_len(y_pred):
#             seq_lens = [y_pred.shape[1]]*y_pred.shape[0]
#             return [seq_lens]
#         return tf.py_function(py_get_seq_len, [data], tf.int32)
#     def index_to_label(data):
#         def py_index_to_label(values):
#             values = tf.map_fn(index_to_label_helper, values, dtype=tf.string)
#             return [values]
示例#17
0
if __name__ == "__main__":
    args = parse_cmdline_flags()

    # Load SSD model
    PATH_TO_FROZEN_GRAPH = args.detection_model_path
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as f:
            od_graph_def.ParseFromString(f.read())
            tf.import_graph_def(od_graph_def, name='')

    # Load CRNN model
    alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
    crnn = CRNN(32, 1, 37, 256)
    if torch.cuda.is_available():
        crnn = crnn.cuda()
    crnn.load_state_dict(torch.load(args.recognition_model_path))
    converter = utils.strLabelConverter(alphabet)
    transformer = dataset.resizeNormalize((100, 32))
    crnn.eval()

    # Open a video file or an image file
    cap = cv2.VideoCapture(args.input if args.input else 0)

    while cv2.waitKey(1) < 0:
        has_frame, frame = cap.read()
        if not has_frame:
            cv2.waitKey(0)
            break
示例#18
0
class Demo(object):
    """Inference wrapper: load a pretrained CRNN and recognize text in images.

    Supports single-image prediction, width-padded batch prediction, and a
    directory/file front-end (``inference``).
    """

    def __init__(self, args):
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
        self.args = args
        self.alphabet = alphabetChinese
        nclass = len(self.alphabet) + 1  # +1 for the CTC blank label
        nc = 1  # grayscale input
        self.net = CRNN(args.imgH, nc, args.nh, nclass)
        self.converter = utils.strLabelConverter(self.alphabet, ignore_case=False)
        self.transformer = resizeNormalize(args.imgH)

        print('loading pretrained model from %s' % args.model_path)
        checkpoint = torch.load(args.model_path)
        if 'model_state_dict' in checkpoint.keys():
            checkpoint = checkpoint['model_state_dict']
        # Strip a leading 'module.' (DataParallel) prefix if present.
        from collections import OrderedDict
        model_dict = OrderedDict()
        for k, v in checkpoint.items():
            if 'module' in k:
                model_dict[k[7:]] = v
            else:
                model_dict[k] = v
        self.net.load_state_dict(model_dict)

        # Remember whether the model actually lives on the GPU so both
        # predict paths move their inputs consistently.
        self.use_cuda = args.cuda and torch.cuda.is_available()
        if self.use_cuda:
            print('available gpus is,', torch.cuda.device_count())
            # BUG FIX: DataParallel has no 'output_dim' keyword (TypeError);
            # the gather dimension is 'dim'.  CRNN output is (T, B, C), so
            # the batch dimension to gather on is 1.
            self.net = torch.nn.DataParallel(self.net, dim=1).cuda()

        self.net.eval()

    def predict(self, image):
        """Recognize a single PIL image; return the decoded string."""
        image = self.transformer(image)
        # BUG FIX: previously keyed on torch.cuda.is_available() alone, which
        # moved the input to the GPU even when the model stayed on the CPU
        # (args.cuda False) and crashed with a device mismatch.
        if self.use_cuda:
            image = image.cuda()
        image = image.view(1, *image.size())  # add batch dimension
        image = Variable(image)

        preds = self.net(image)
        # Greedy (best-path) CTC decode: argmax over classes per timestep.
        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        preds_size = Variable(torch.IntTensor([preds.size(0)]))
        raw_pred = self.converter.decode(preds.data, preds_size.data, raw=True)
        sim_pred = self.converter.decode(preds.data, preds_size.data, raw=False)
        print('%-20s => %-20s' % (raw_pred, sim_pred))

        return sim_pred

    def predict_batch(self, images):
        """Recognize a list of PIL images in mini-batches; return decoded strings."""
        N = len(images)
        n_batch = N // self.args.batch_size
        n_batch += 1 if N % self.args.batch_size else 0
        res = []
        # Outer index renamed from 'i': the original shadowed it in the inner
        # loops, which was harmless but confusing.
        for b in range(n_batch):
            batch = images[b * self.args.batch_size:min((b + 1) * self.args.batch_size, N)]
            # Transform each image and track the widest one in the batch.
            maxW = 0
            for j in range(len(batch)):
                batch[j] = self.transformer(batch[j])
                maxW = max(maxW, batch[j].shape[2])

            # Right-pad narrower images with zeros to the common width.
            for j in range(len(batch)):
                if batch[j].shape[2] < maxW:
                    batch[j] = torch.cat(
                        (batch[j],
                         torch.zeros((1, self.args.imgH, maxW - batch[j].shape[2]),
                                     dtype=batch[j].dtype)), 2)
            batch_imgs = torch.cat([t.unsqueeze(0) for t in batch], 0)
            # BUG FIX: predict() moved its input to the GPU but this path did
            # not; keep the two inference paths consistent.
            if self.use_cuda:
                batch_imgs = batch_imgs.cuda()
            preds = self.net(batch_imgs)
            preds_size = Variable(torch.IntTensor([preds.size(0)] * len(batch)))
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            raw_preds = self.converter.decode(preds.data, preds_size.data, raw=True)
            sim_preds = self.converter.decode(preds.data, preds_size.data, raw=False)
            for raw_pred, sim_pred in zip(raw_preds, sim_preds):
                print('%-20s => %-20s' % (raw_pred, sim_pred))
            res.extend(sim_preds)
        return res

    def inference(self, image_path, batch_pred=False):
        """Recognize a single file or every image in a directory.

        When batch_pred is True the images are collected and run through
        predict_batch; otherwise each is predicted individually.
        """
        if os.path.isdir(image_path):
            file_list = os.listdir(image_path)
            image_list = [os.path.join(image_path, i) for i in file_list if i.rsplit('.')[-1].lower() in img_types]
        else:
            image_list = [image_path]

        res = []
        images = []
        for img_path in image_list:
            image = Image.open(img_path).convert('L')  # force grayscale
            if not batch_pred:
                sim_pred = self.predict(image)
                res.append(sim_pred)
            else:
                images.append(image)
        if batch_pred and images:
            res = self.predict_batch(images)
        return res
示例#19
0
import torch
from torch.autograd import Variable
import utils
import dataset
from PIL import Image
from models.crnn import CRNN


# Demo setup: run a pretrained CRNN on a single image.
model_path = './data/crnn.pth'
img_path = './data/demo.png'
# 36 characters (digits + lowercase letters).
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'

# CRNN(imgH=32, nc=1, nclass=37, nh=256): 36 characters + 1 CTC blank.
model = CRNN(32, 1, 37, 256)

# Move the model to the GPU *before* loading weights so the state dict
# maps onto CUDA tensors directly.
if torch.cuda.is_available():
    model = model.cuda()
print('loading pretrained model from %s' % model_path)

model.load_state_dict(torch.load(model_path))

# Converts between label-index sequences and strings (CTC decoding).
converter = utils.strLabelConverter(alphabet)

# NOTE(review): presumably resizes to 100x32 and normalizes to match the
# training pipeline — confirm against dataset.resizeNormalize.
transformer = dataset.resizeNormalize((100, 32))
image = Image.open(img_path).convert('L')  # force grayscale
image = transformer(image)
if torch.cuda.is_available():
    image = image.cuda()
image = image.view(1, *image.size())  # add batch dimension
image = Variable(image)

model.eval()
示例#20
0
def main():
    """Train a CRNN on the Hub dataset with CTC loss and validate every epoch.

    Configuration (paths, batch sizes, optimizer switches, device) comes
    entirely from the project's ``Config`` object.
    """
    config = Config()

    if not os.path.exists(config.expr_dir):
        os.makedirs(config.expr_dir)

    if torch.cuda.is_available() and not config.use_cuda:
        print("WARNING: You have a CUDA device, so you should probably set cuda in params.py to True")

    # Load the training dataset (variable-width images collated together).
    train_dataset = HubDataset(config, "train", transform=None)

    train_kwargs = {'num_workers': 2, 'pin_memory': True,
                    'collate_fn': alignCollate(config.img_height, config.img_width, config.keep_ratio)} if torch.cuda.is_available() else {}

    training_data_batch = DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, drop_last=False, **train_kwargs)

    # Load the fixed-length validation dataset.
    eval_dataset = HubDataset(config, "eval", transform=transforms.Compose([ResizeNormalize(config.img_height, config.img_width)]))
    eval_kwargs = {'num_workers': 2, 'pin_memory': False} if torch.cuda.is_available() else {}
    eval_data_batch = DataLoader(eval_dataset, batch_size=config.eval_batch_size, shuffle=False, drop_last=False, **eval_kwargs)

    # Variable-length validation dataset (kept for reference):
    # eval_dataset = HubDataset(config, "eval")
    # eval_kwargs = {'num_workers': 2, 'pin_memory': False,
    #                'collate_fn': alignCollate(config.img_height, config.img_width, config.keep_ratio)} if torch.cuda.is_available() else {}
    # eval_data_batch = DataLoader(eval_dataset, batch_size=config.eval_batch_size, shuffle=False, drop_last=False, **eval_kwargs)

    # Build the network: +1 class for the CTC blank label.
    nclass = len(config.label_classes) + 1
    crnn = CRNN(config.img_height, config.nc, nclass, config.hidden_size, n_rnn=config.n_layers)
    # Optionally resume from pretrained weights.
    if config.pretrained != '':
        print('loading pretrained model from %s' % config.pretrained)
        crnn.load_state_dict(torch.load(config.pretrained))
    print(crnn)

    # Running average of the training loss (reset every display interval).
    loss_avg = utils.averager()

    # Converts between strings and label-index sequences.
    converter = utils.strLabelConverter(config.label_classes)

    criterion = CTCLoss()           # CTC loss function

    # Pre-allocated placeholder tensors, refilled by utils.loadData each batch.
    # NOTE(review): channels are hard-coded to 3 (vs. config.nc for the model)
    # and the width argument reuses img_height — presumably utils.loadData
    # resizes the tensor to the real batch shape, so only dtype/device matter
    # here; confirm against utils.loadData.
    image = torch.FloatTensor(config.train_batch_size, 3, config.img_height, config.img_height)
    text = torch.LongTensor(config.train_batch_size * 5)
    length = torch.LongTensor(config.train_batch_size)

    if config.use_cuda and torch.cuda.is_available():
        criterion = criterion.cuda()
        image = image.cuda()
        crnn = crnn.to(config.device)

    image = Variable(image)
    text = Variable(text)
    length = Variable(length)

    # Choose the optimizer.
    if config.adam:
        optimizer = optim.Adam(crnn.parameters(), lr=config.lr, betas=(config.beta1, 0.999))
    elif config.adadelta:
        optimizer = optim.Adadelta(crnn.parameters())
    else:
        optimizer = optim.RMSprop(crnn.parameters(), lr=config.lr)

    def val(net, criterion, eval_data_batch):
        """Evaluate on the validation set; print loss and exact-match accuracy."""
        print('Start val')
        for p in crnn.parameters():
            p.requires_grad = False  # re-enabled by the training loop
        net.eval()

        n_correct = 0
        loss_avg_eval = utils.averager()
        for data in eval_data_batch:
            cpu_images, cpu_texts = data
            batch_size = cpu_images.size(0)
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts)
            utils.loadData(text, t)
            utils.loadData(length, l)
            preds = crnn(image)
            preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
            cost = criterion(preds, text, preds_size, length) / batch_size
            loss_avg_eval.add(cost)         # accumulate validation loss

            # Greedy (best-path) CTC decode.
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
            cpu_texts_decode = []
            for i in cpu_texts:
                cpu_texts_decode.append(i)
            for pred, target in zip(sim_preds, cpu_texts_decode):       # exact-match accuracy
                if pred == target:
                    n_correct += 1

            raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:config.n_val_disp]
            for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts_decode):
                print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

        accuracy = n_correct / float(len(eval_dataset))
        # BUG FIX: this printed the *training* averager (loss_avg) instead of
        # the validation averager filled above; also fixed the "accuray" typo.
        print('Val loss: %f, accuracy: %f' % (loss_avg_eval.val(), accuracy))

    # Train on a single batch; returns the (batch-averaged) CTC loss.
    def train(net, criterion, optimizer, data):
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)             # actual batch size (last batch may be smaller)
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts)          # strings -> label indices
        utils.loadData(text, t)
        utils.loadData(length, l)
        optimizer.zero_grad()                       # clear gradients
        preds = net(image)
        preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        cost.backward()
        optimizer.step()
        return cost

    for epoch in range(config.nepoch):
        i = 0
        for batch_data in training_data_batch:
            # val() freezes the parameters; unfreeze before each train step.
            for p in crnn.parameters():
                p.requires_grad = True
            crnn.train()
            cost = train(crnn, criterion, optimizer, batch_data)
            loss_avg.add(cost)
            i += 1

            if i % config.displayInterval == 0:
                print('[%d/%d][%d/%d] Loss: %f' %
                      (epoch, config.nepoch, i, len(training_data_batch), loss_avg.val()))
                loss_avg.reset()

            # if i % config.valInterval == 0:
            #     val(crnn, criterion, eval_data_batch)
            #
            # # do checkpointing
            # if i % config.saveInterval == 0:
            #     torch.save(crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(config.expr_dir, epoch, i))

        # Validate and checkpoint at the end of every epoch.
        val(crnn, criterion, eval_data_batch)
        torch.save(crnn.state_dict(), '{0}/netCRNN_{1}_end.pth'.format(config.expr_dir, epoch))
示例#21
0
def test(nets,
         data,
         abc,
         cuda,
         visualize,
         batch_size=1,
         tb_writer=None,
         n_iter=0,
         initial_title="",
         loss_function=None,
         is_trian=True,
         output_path=None,
         do_beam_search=False,
         do_results=False,
         word_lexicon=None):
    """Evaluate an ensemble of CRNN nets on *data* (majority vote per timestep).

    Returns (acc, avg_ed, avg_no_stop_ed) — plus avg_loss when a
    loss_function is supplied.  avg_ed is mean normalized edit distance over
    the mispredicted samples; *_no_stop_* ignores Tibetan stop characters.
    NOTE(review): avg_accuracy is never incremented (the lexicon-matching
    code is commented out), so the returned acc is always 0 — confirm intent.
    """
    collate = lambda x: text_collate(x, do_mask=False)
    data_loader = DataLoader(data,
                             batch_size=1,
                             num_workers=2,
                             shuffle=False,
                             collate_fn=collate)
    stop_characters = ['-', '.', '༎', '༑', '།', '་']
    garbage = '-'  # CTC blank/garbage symbol stripped before edit distance
    count = 0
    tp = 0
    avg_ed = 0
    avg_no_stop_ed = 0
    avg_accuracy = 0
    avg_loss = 0
    min_ed = 1000
    iterator = tqdm(data_loader)
    # BUG FIX: the original chained assignment
    # (all_pred_text = all_label_text = all_im_pathes = []) bound all three
    # names to the SAME list, so an append through one alias showed up in
    # all of them.  Bind three independent lists.
    all_pred_text = []
    all_label_text = []
    all_im_pathes = []
    test_letter_statistics = Statistics()
    im_by_error = {}

    for i, sample in enumerate(iterator):
        # During training, cap the evaluation at 1000 samples for speed.
        if is_trian and (i > 1000):
            break
        imgs = Variable(sample["img"])

        img_seq_lens = sample["im_seq_len"]
        losses = 0
        all_nets_out = []
        # Run every net in the ensemble and collect per-timestep argmaxes.
        for neti, net in enumerate(nets):
            if cuda:
                imgs = imgs.cuda(neti)
            orig_seq = net(imgs,
                           img_seq_lens,
                           decode=False,
                           do_beam_search=do_beam_search)
            if loss_function is not None:
                labels_flatten = Variable(sample["seq"]).view(-1)
                label_lens = Variable(sample["seq_len"].int())
                loss = loss_function(
                    orig_seq, labels_flatten,
                    Variable(torch.IntTensor(np.array(img_seq_lens))),
                    label_lens) / batch_size
                # NOTE(review): .data[0] is the pre-0.4 PyTorch idiom
                # (loss.item() on modern versions) — kept for compatibility
                # with the rest of this file.
                losses += loss.data[0]
            orig_seq = orig_seq.transpose(1, 0)
            orig_seq = orig_seq.cpu().data.numpy()
            net_pred = np.argmax(orig_seq, axis=2)
            all_nets_out.append(net_pred)
        all_nets_out = np.stack(all_nets_out)

        # Majority vote across nets at each timestep.
        all_nets_out, _ = stats.mode(
            all_nets_out,
            axis=0,
        )
        all_nets_out = np.squeeze(all_nets_out, axis=0)
        out = []
        for net_out in all_nets_out:
            net_out = net_out.tolist()
            out.append(CRNN.label_to_string(net_out, nets[0].lexicon))

        losses /= float(len(nets))
        avg_loss += losses
        gt = (sample["seq"].numpy()).tolist()
        lens = sample["seq_len"].numpy().tolist()
        labels_flatten = Variable(sample["seq"]).view(-1)
        label_lens = Variable(sample["seq_len"].int())
        '''
        if output_path is not None:
            preds_text = net.decode(orig_seq, data.get_lexicon())
            all_pred_text = all_pred_text + [''.join(c for c in pd if c != garbage)+'\n' for pd in preds_text]

            label_text = net.decode_flatten(labels_flatten, label_lens, data.get_lexicon())
            all_label_text = all_label_text + [lb + '\n' for lb in label_text]

            all_im_pathes.append(sample["im_path"] + '\n')#[imp +'\n' for imp in sample["im_path"]]
        '''

        #if i == 0:
        #    if tb_writer is not None:
        #        print_data_visuals(net, tb_writer, data.get_lexicon(), sample["img"], labels_flatten, label_lens, orig_seq, n_iter,
        #                       initial_title)

        pos = 0
        key = ''
        # Per-sample scoring; 'j' renamed from 'i' so the outer batch index
        # is not shadowed.
        for j in range(len(out)):
            gts = ''.join(abc[c] for c in gt[pos:pos + lens[j]])

            pos += lens[j]
            if gts == out[j]:
                tp += 1
            else:
                cur_out = ''.join(c for c in out[j] if c != garbage)
                cur_gts = ''.join(c for c in gts if c != garbage)
                cur_out_no_stops = ''.join(c for c in out[j]
                                           if not c in stop_characters)
                cur_gts_no_stops = ''.join(c for c in gts
                                           if not c in stop_characters)
                cur_ed = editdistance.eval(cur_out, cur_gts) / len(cur_gts)
                #if word_lexicon is not None:
                #    closest_word = get_close_matches(cur_out, word_lexicon, n=1, cutoff=0.2)
                #else:
                #    closest_word = cur_out

                #if len(closest_word) > 0 and closest_word[0] == cur_gts:
                #    avg_accuracy += 1

                errors, matches, bp = my_edit_distance_backpointer(
                    cur_out_no_stops, cur_gts_no_stops)
                test_letter_statistics.add_data(bp)
                #my_no_stop_ed = errors / max(len(cur_out_no_stops), len(cur_gts_no_stops))
                #cur_no_stop_ed = editdistance.eval(cur_out_no_stops, cur_gts_no_stops) / max(len(cur_out_no_stops), len(cur_gts_no_stops))
                if do_results:
                    im_by_error[sample["im_path"]] = cur_ed
                my_no_stop_ed = errors / len(cur_gts_no_stops)
                cur_no_stop_ed = editdistance.eval(
                    cur_out_no_stops, cur_gts_no_stops) / len(cur_gts_no_stops)

                # Sanity check: our backpointer edit distance should agree
                # with the library's.
                if my_no_stop_ed != cur_no_stop_ed:
                    print('old ed: {} , vs. new ed: {}\n'.format(
                        my_no_stop_ed, cur_no_stop_ed))
                avg_no_stop_ed += cur_no_stop_ed
                avg_ed += cur_ed
                if cur_ed < min_ed: min_ed = cur_ed
            count += 1
            if visualize:
                status = "pred: {}; gt: {}".format(out[j], gts)
                iterator.set_description(status)
                img = imgs[j].permute(1, 2,
                                      0).cpu().data.numpy().astype(np.uint8)
                cv2.imshow("img", img)
                key = chr(cv2.waitKey() & 255)
                if key == 'q':
                    break

        if not visualize:
            # BUG FIX: both fields used positional index {0}, so the shown
            # avg_ed was actually the accuracy.
            iterator.set_description("acc: {0:.4f}; avg_ed: {1:.4f}".format(
                float(tp) / float(count),
                float(avg_ed) / float(count)))
    #with open(output_path + '_{}_{}_statistics.pkl'.format(initial_title,n_iter), 'wb') as sf:

    #    pkl.dump(test_letter_statistics.total_actions_hists, sf)

    if do_results and output_path is not None:
        # NOTE(review): the code that fills all_im_pathes / all_pred_text /
        # all_label_text is commented out above, so this section operates on
        # empty lists — confirm whether it should be re-enabled together.
        print('printing results! :)')
        sorted_im_by_error = sorted(im_by_error.items(),
                                    key=operator.itemgetter(1))
        sorted_im = [key for (key, value) in sorted_im_by_error]
        all_im_pathes_no_new_line = [
            im.replace('\n', '') for im in all_im_pathes
        ]
        printed_res_best = ""
        printed_res_worst = ""
        for im in sorted_im[:20]:
            im_id = all_im_pathes_no_new_line.index(im)
            pred = all_pred_text[im_id]
            label = all_label_text[im_id]
            printed_res_best += im + '\n' + label + pred

        for im in list(reversed(sorted_im))[:20]:
            im_id = all_im_pathes_no_new_line.index(im)
            pred = all_pred_text[im_id]
            label = all_label_text[im_id]
            printed_res_worst += im + '\n' + label + pred

        with open(
                output_path + '_{}_{}_sorted_images_by_errors.txt'.format(
                    initial_title, n_iter), 'w') as fp:
            fp.writelines([
                key + ',' + str(value) + '\n'
                for (key, value) in sorted_im_by_error
            ])

        with open(
                output_path +
                '_{}_{}_res_on_best.txt'.format(initial_title, n_iter),
                'w') as fp:
            fp.writelines([printed_res_best])
            with open(
                    output_path +
                    '_{}_{}_res_on_worst.txt'.format(initial_title, n_iter),
                    'w') as fp:
                fp.writelines([printed_res_worst])
        os.makedirs(output_path, exist_ok=True)
        with open(
                output_path + '_{}_{}_pred.txt'.format(initial_title, n_iter),
                'w') as fp:
            fp.writelines(all_pred_text)
        with open(
                output_path + '_{}_{}_label.txt'.format(initial_title, n_iter),
                'w') as fp:
            fp.writelines(all_label_text)
        with open(output_path + '_{}_{}_im.txt'.format(initial_title, n_iter),
                  'w') as fp:
            fp.writelines(all_im_pathes)
        stop_characters = ['-', '.', '༎', '༑', '།', '་']

        all_pred_text = [
            ''.join(c for c in line if not c in stop_characters)
            for line in all_pred_text
        ]
        with open(
                output_path +
                '_{}_{}_pred_no_stopchars.txt'.format(initial_title, n_iter),
                'w') as rf:
            rf.writelines(all_pred_text)
        all_label_text = [
            ''.join(c for c in line if not c in stop_characters)
            for line in all_label_text
        ]
        with open(
                output_path +
                '_{}_{}_label_no_stopchars.txt'.format(initial_title, n_iter),
                'w') as rf:
            rf.writelines(all_label_text)

    acc = float(avg_accuracy) / float(count)
    avg_ed = float(avg_ed) / float(count)
    avg_no_stop_ed = float(avg_no_stop_ed) / float(count)
    if loss_function is not None:
        avg_loss = float(avg_loss) / float(count)
        return acc, avg_ed, avg_no_stop_ed, avg_loss
    return acc, avg_ed, avg_no_stop_ed
# ==== 示例#22 (Example #22) — snippet separator from the original scrape; stray marker line: 0 ====
    # testdataset = PathDataset(opt.valRoot, alphabetChinese)

    # Validation loader over the synthetic-Chinese test split; alignCollate
    # resizes/pads batch images to a common (imgH, imgW).
    testdataset = SytheticChinese(opt.valRoot, 'test')
    val_loader = torch.utils.data.DataLoader(testdataset,
                                             shuffle=False,
                                             batch_size=opt.batch_size,
                                             num_workers=int(opt.workers),
                                             collate_fn=alignCollate(
                                                 imgH=imgH,
                                                 imgW=imgW,
                                                 keep_ratio=keep_ratio))

    alphabet = keys.alphabetChinese
    print("char num ", len(alphabet))
    # len(alphabet) + 1 — the extra class is presumably the CTC blank
    # label; confirm against CRNN's definition.
    model = CRNN(32, 1, len(alphabet) + 1, 256, 1)

    converter = strLabelConverter(''.join(alphabet))

    # map_location keeps the checkpoint on CPU regardless of the device it
    # was saved from.
    state_dict = torch.load("../SceneOcr/model/ocr-lstm.pth",
                            map_location=lambda storage, loc: storage)
    new_state_dict = OrderedDict()
    # Filter out BatchNorm "num_batches_tracked" bookkeeping entries; all
    # other weights are copied through unchanged.
    for k, v in state_dict.items():
        name = k
        if "num_batches_tracked" not in k:
            # name = name.replace('module.', '')  # remove `module.`
            new_state_dict[name] = v
    model.cuda()
    # NOTE(review): new_state_dict is built but never applied to the model
    # in the visible lines — presumably load_state_dict happens just below
    # this fragment ("# load params"); verify.
    model = torch.nn.DataParallel(model, device_ids=[0, 1, 2])

    # load params
# ==== 示例#23 (Example #23) — snippet separator from the original scrape; stray marker line: 0 ====
                   "early_stopping_patience": 20,
                   "reduce_lr_on_plateau_monitor": "val_loss",
                   "reduce_lr_on_plateau_min_lr": 1e-6,
                   "reduce_lr_on_plateau_factor": .33333,
                   "reduce_lr_on_plateau_patience": 10,
               })
    config = wandb.config

    callbacks = get_callbacks(
        early_stopping_patience=config.early_stopping_patience,
        checkpoint_path=checkpoint_path,
        reduce_lr_on_plateau_monitor=config.reduce_lr_on_plateau_monitor,
        reduce_lr_on_plateau_factor=config.reduce_lr_on_plateau_factor,
        reduce_lr_on_plateau_patience=config.reduce_lr_on_plateau_patience,
        reduce_lr_on_plateau_min_lr=config.reduce_lr_on_plateau_min_lr)

    model = CRNN(n_classes=config.num_classes,
                 batch_size=config.batch_size,
                 lr=config.learning_rate,
                 optimizer_type=config.optimizer,
                 reg=config.l2_reg,
                 cnn_weights_path=cnn_weights_path)

    model.fit(train,
              epochs=config.epochs,
              steps_per_epoch=steps_per_epoch,
              validation_data=valid,
              validation_steps=validation_steps,
              callbacks=callbacks,
              verbose=1)