Ejemplo n.º 1
0
def main(data_path, abc, seq_proj, backend, snapshot, input_size, gpu,
         visualize):
    """Run CRNN inference over a test dataset and hand results to detect().

    Parameters
    ----------
    data_path : str or None
        Root of the dataset to evaluate; must be provided.
    abc : str
        Alphabet argument from the CLI (currently overridden below;
        kept in the signature for caller compatibility).
    seq_proj : str
        Sequence-projection size encoded as a "WxH" string.
    backend : str
        CNN backbone name forwarded to CRNN.
    snapshot : str or None
        Path to a saved state dict to restore, if any.
    input_size : str
        Input image size encoded as a "WxH" string.
    gpu : str
        Value for CUDA_VISIBLE_DEVICES; empty string selects CPU.
    visualize : bool
        Forwarded to detect().
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # Original code used `gpu is not ''`, which compares object identity,
    # not string value (and is a SyntaxWarning on Python >= 3.8).
    cuda = gpu != ''
    # NOTE: the incoming `abc` is intentionally replaced with the full
    # alphanumeric alphabet (original behavior preserved).
    abc = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
    input_size = [int(x) for x in input_size.split('x')]
    transform = Compose(
        [Rotation(), Resize(size=(input_size[0], input_size[1]))])
    if data_path is None:
        # Fail fast with a clear message instead of the NameError that
        # `detect(net, data, ...)` would raise below.
        raise ValueError("data_path must be provided")
    data = LoadDataset(data_path=data_path,
                       mode="test",
                       transform=transform)
    seq_proj = [int(x) for x in seq_proj.split('x')]
    net = CRNN(abc=abc, seq_proj=seq_proj, backend=backend)
    if snapshot is not None:
        load_weights(net, torch.load(snapshot))
    if cuda:
        net = net.cuda()
    net = net.eval()
    detect(net, data, cuda, visualize)
Ejemplo n.º 2
0
def load_model(lexicon,
               seq_proj=None,
               backend='resnet18',
               base_model_dir=None,
               snapshot=None,
               cuda=True,
               do_beam_search=False,
               dropout_conv=False,
               dropout_rnn=False,
               dropout_output=False,
               do_ema=False,
               ada_after_rnn=False,
               ada_before_rnn=False,
               rnn_hidden_size=128):
    """Build a CRNN, optionally restore a snapshot, optionally move to GPU.

    Parameters
    ----------
    lexicon
        Character set / lexicon forwarded to CRNN.
    seq_proj : list[int] or None
        Sequence-projection size; defaults to [0, 0] when None.
    backend : str
        CNN backbone name.
    base_model_dir : str or None
        Directory of a base model, forwarded to CRNN.
    snapshot : str or None
        Path to a saved state dict to load into the network.
    cuda : bool
        Move the network to the GPU when True.
    do_beam_search, dropout_conv, dropout_rnn, dropout_output, do_ema,
    ada_after_rnn, ada_before_rnn, rnn_hidden_size
        Forwarded unchanged to CRNN.

    Returns
    -------
    CRNN
        The constructed (and possibly restored / GPU-resident) network.
    """
    # A literal `[0, 0]` default would be a shared mutable default
    # argument; use None as the sentinel instead.
    if seq_proj is None:
        seq_proj = [0, 0]
    net = CRNN(lexicon=lexicon,
               seq_proj=seq_proj,
               backend=backend,
               base_model_dir=base_model_dir,
               do_beam_search=do_beam_search,
               dropout_conv=dropout_conv,
               dropout_rnn=dropout_rnn,
               dropout_output=dropout_output,
               do_ema=do_ema,
               ada_after_rnn=ada_after_rnn,
               ada_before_rnn=ada_before_rnn,
               rnn_hidden_size=rnn_hidden_size)
    if snapshot is not None:
        print('snapshot is: {}'.format(snapshot))
        load_weights(net, torch.load(snapshot))
    if cuda:
        print('setting network on gpu')
        net = net.cuda()
        print('set network on gpu')
    return net
def load_model(abc,
               seq_proj=None,
               backend='resnet18',
               snapshot=None,
               cuda=True):
    """Build a DataParallel-wrapped CRNN and optionally restore/move it.

    Parameters
    ----------
    abc
        Alphabet forwarded to CRNN.
    seq_proj : list[int] or None
        Sequence-projection size; defaults to [0, 0] when None.
    backend : str
        CNN backbone name.
    snapshot : str or None
        Path to a saved state dict to load into the (wrapped) network.
    cuda : bool
        Move the network to the GPU when True.

    Returns
    -------
    torch.nn.DataParallel
        The wrapped network.
    """
    # A literal `[0, 0]` default would be a shared mutable default
    # argument; use None as the sentinel instead.
    if seq_proj is None:
        seq_proj = [0, 0]
    net = CRNN(abc=abc, seq_proj=seq_proj, backend=backend)
    net = nn.DataParallel(net)
    if snapshot is not None:
        # NOTE: weights are loaded AFTER the DataParallel wrap, so the
        # snapshot keys are expected to carry the `module.` prefix.
        load_weights(net, torch.load(snapshot))
    if cuda:
        net = net.cuda()
    return net
Ejemplo n.º 4
0
    # Evaluation-only loader; alignCollate resizes/pads each batch to a
    # fixed imgH x imgW (keep_ratio controls aspect-ratio handling).
    val_loader = torch.utils.data.DataLoader(testdataset,
                                             shuffle=False,
                                             batch_size=opt.batch_size,
                                             num_workers=int(opt.workers),
                                             collate_fn=alignCollate(
                                                 imgH=imgH,
                                                 imgW=imgW,
                                                 keep_ratio=keep_ratio))

    alphabet = keys.alphabetChinese
    print("char num ", len(alphabet))
    # +1 output class for the CTC blank label.
    model = CRNN(32, 1, len(alphabet) + 1, 256, 1)

    converter = strLabelConverter(''.join(alphabet))

    # map_location keeps GPU-saved tensors on CPU until the model is
    # moved below. NOTE(review): the checkpoint path is hard-coded.
    state_dict = torch.load("../SceneOcr/model/ocr-lstm.pth",
                            map_location=lambda storage, loc: storage)
    # Drop BatchNorm `num_batches_tracked` entries so an older model
    # definition can load the checkpoint without key mismatches.
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k
        if "num_batches_tracked" not in k:
            # name = name.replace('module.', '')  # remove `module.`
            new_state_dict[name] = v
    model.cuda()
    # NOTE(review): device_ids=[0, 1, 2] assumes three GPUs are present —
    # confirm against the target machine.
    model = torch.nn.DataParallel(model, device_ids=[0, 1, 2])

    # load params
    model.load_state_dict(new_state_dict)
    model.eval()

    curAcc = val(model, converter, val_loader, max_iter=5)
Ejemplo n.º 5
0
import torch
from torch.autograd import Variable
import utils
import dataset
from PIL import Image
from models.crnn import CRNN


model_path = './data/crnn.pth'
img_path = './data/demo.png'
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'

# 32-pixel-high, 1-channel input; 37 classes = 36 chars + CTC blank;
# 256 hidden units.
model = CRNN(32, 1, 37, 256)

if torch.cuda.is_available():
    model = model.cuda()
print('loading pretrained model from %s' % model_path)

# map_location='cpu' lets a GPU-saved checkpoint load on CPU-only hosts;
# load_state_dict copies the values onto the model's device either way.
model.load_state_dict(torch.load(model_path, map_location='cpu'))

converter = utils.strLabelConverter(alphabet)

transformer = dataset.resizeNormalize((100, 32))
# Convert to grayscale to match the model's single input channel.
image = Image.open(img_path).convert('L')
image = transformer(image)
if torch.cuda.is_available():
    image = image.cuda()
# Add the batch dimension: (C, H, W) -> (1, C, H, W).
image = image.view(1, *image.size())
# Variable is a no-op wrapper on modern PyTorch; kept for compatibility.
image = Variable(image)

model.eval()
Ejemplo n.º 6
0
    args = parse_cmdline_flags()

    # Load SSD model
    # NOTE(review): tf.GraphDef / tf.gfile are TensorFlow 1.x APIs
    # (tf.compat.v1.* under TF2) — confirm the pinned TF version.
    PATH_TO_FROZEN_GRAPH = args.detection_model_path
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as f:
            od_graph_def.ParseFromString(f.read())
            tf.import_graph_def(od_graph_def, name='')

    # Load CRNN model
    alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
    # 32-high, 1-channel input; 37 classes = 36 chars + CTC blank; 256 hidden.
    crnn = CRNN(32, 1, 37, 256)
    if torch.cuda.is_available():
        crnn = crnn.cuda()
    crnn.load_state_dict(torch.load(args.recognition_model_path))
    converter = utils.strLabelConverter(alphabet)
    transformer = dataset.resizeNormalize((100, 32))
    crnn.eval()

    # Open a video file or an image file
    # Falls back to webcam 0 when no --input path was given.
    cap = cv2.VideoCapture(args.input if args.input else 0)

    # Process frames until any key is pressed (waitKey returns -1 on timeout).
    while cv2.waitKey(1) < 0:
        has_frame, frame = cap.read()
        if not has_frame:
            # Hold the last frame until a key press, then stop.
            cv2.waitKey(0)
            break

        im_height, im_width, _ = frame.shape
def main():
    """Train a CRNN recognizer with a CTC head plus an auxiliary attention
    head; validate every epoch, test and checkpoint periodically.

    All hyper-parameters are read from ./config.yaml.
    """

    print(torch.__version__)

    with open('config.yaml') as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    print(torch.cuda.is_available())
    torch.backends.cudnn.benchmark = True

    char_set = config['char_set']
    # if config['method'] == 'ctc':
    # Build char<->index maps for both decoding schemes and stash them
    # in the config so datasets/eval helpers can reuse them.
    char2idx_ctc, idx2char_ctc = get_char_dict_ctc(char_set)
    char2idx_att, idx2char_att = get_char_dict_attention(char_set)
    config['char2idx_ctc'] = char2idx_ctc
    config['idx2char_ctc'] = idx2char_ctc
    config['char2idx_att'] = char2idx_att
    config['idx2char_att'] = idx2char_att

    batch_size = config['batch_size']

    if not os.path.exists(config['save_path']):
        os.mkdir(config['save_path'])
    print(config)

    train_dataset = TextRecDataset(config, phase='train')
    val_dataset = TextRecDataset(config, phase='val')
    test_dataset = TextRecDataset(config, phase='test')
    trainloader = data.DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=cpu_count(),
                                  pin_memory=False)

    valloader = data.DataLoader(val_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=cpu_count(),
                                pin_memory=False)

    testloader = data.DataLoader(test_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=cpu_count(),
                                 pin_memory=False)

    # +1 class for the CTC blank symbol.
    class_num = len(config['char_set']) + 1
    print('class_num', class_num)
    model = CRNN(class_num)
    # decoder = Decoder(class_num, config['max_string_len'], char2idx_att)
    attention_head = AttentionHead(class_num, config['max_string_len'], char2idx_att)

    # criterion = nn.CTCLoss(blank=char2idx['-'], reduction='mean')
    criterion_ctc = CTCFocalLoss(blank=char2idx_ctc['-'], gamma=0.5)
    # reduction='none' so per-position losses can be masked below.
    criterion_att = nn.CrossEntropyLoss(reduction='none')

    if config['use_gpu']:
        model = model.cuda()
        # decoder = decoder.cuda()
        attention_head = attention_head.cuda()
    summary(model, (1, 32, 400))

    # model = torch.nn.DataParallel(model)

    # optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=1e-2, weight_decay=5e-4)
    # Single optimizer drives both the backbone and the attention head.
    optimizer = torch.optim.SGD([{'params': model.parameters()},
                                 {'params': attention_head.parameters()}], lr=0.001, momentum=0.9, weight_decay=5e-4)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[500, 800], gamma=0.1)

    print('train start, total batches %d' % len(trainloader))
    iter_cnt = 0
    for i in range(1, config['epochs']+1):
        start = time.time()
        model.train()
        attention_head.train()
        for j, batch in enumerate(trainloader):

            iter_cnt += 1
            # Batch layout (from TextRecDataset): images, label lengths,
            # raw strings, CTC targets + mask, attention targets + mask.
            imgs = batch[0].cuda()
            labels_length = batch[1].cuda()
            labels_str = batch[2]
            labels_ctc = batch[3].cuda().long()
            labels_ctc_mask = batch[4].cuda().float()
            labels_att = batch[5].cuda().long()
            labels_att_mask = batch[6].cuda().float()

            if config['method'] == 'ctc':
                # CTC loss
                outputs, cnn_features = model(imgs)
                log_prob = outputs.log_softmax(dim=2)
                # t = time steps, n = batch size, c = classes.
                t,n,c = log_prob.size(0),log_prob.size(1),log_prob.size(2)
                # Every sample claims the full t time steps as input length.
                input_length = (torch.ones((n,)) * t).cuda().int()
                loss_ctc = criterion_ctc(log_prob, labels_ctc, input_length, labels_length)

                # attention loss   
                outputs = attention_head(cnn_features, labels_att)
                # (seq, batch, classes) -> (batch, classes, seq) as
                # CrossEntropyLoss expects — TODO confirm head output order.
                probs = outputs.permute(1, 2, 0)
                losses_att = criterion_att(probs, labels_att)
                # Mask padding positions, then average over real tokens.
                losses_att = losses_att * labels_att_mask
                losses_att = losses_att.sum() / labels_att_mask.sum()

                loss = loss_ctc + losses_att

            else:
                # cross_entropy loss
                # NOTE(review): this branch is broken — `decoder`,
                # `label_att` and `criterion` are undefined here (the
                # Decoder construction above is commented out), so it
                # raises NameError when config['method'] != 'ctc'.
                outputs_ctc, sqs = model(imgs)
                outputs_att = decoder(sqs, label_att)

                outputs = outputs_att.permute(1, 2, 0)
                losses = criterion(outputs, label_att)
                losses = losses * labels_att_mask
                loss = losses.sum() / labels_att_mask.sum()
 
                # attention loss   

            optimizer.zero_grad()            
            loss.backward()
            # nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
            optimizer.step()

            if iter_cnt % config['print_freq'] == 0:
                print('epoch %d, iter %d, train loss %f' % (i, iter_cnt, loss.item()))

        print('epoch %d, time %f' % (i, (time.time() - start)))
        scheduler.step()

        print("validating...")
        
        if config['method'] == 'ctc':
            eval_ctc(model, valloader, idx2char_ctc)
        else:
            # NOTE(review): `decoder` is undefined here as well.
            eval_attention(model, decoder, valloader, idx2char_att)

        if i % config['test_freq'] == 0:
            print("testing...")
            if config['method'] == 'ctc':
                line_acc, rec_score = eval_ctc(model, testloader, idx2char_ctc)
            else:
                line_acc, rec_score = eval_attention(model, decoder, testloader, idx2char_att)

        if i % config['save_freq'] == 0:
            # NOTE(review): line_acc/rec_score are only assigned inside the
            # test branch above — this line raises NameError on any epoch
            # where save_freq divides i but test_freq does not.
            save_file_name = f"epoch_{i}_acc_{line_acc:.3f}_rec_score_{rec_score:.3f}.pth"
            save_file = os.path.join(config['save_path'], save_file_name)
            torch.save(model.state_dict(), save_file)