Пример #1
0
def demo(opt):
    """Run a saved recognition model over a folder of images and log the results.

    Builds the label converter and ``Model`` from ``opt``, loads the weights
    stored at ``opt.saved_model``, greedily decodes every batch read from
    ``opt.image_folder`` and, for each image, prints and appends to
    ``./log_demo_result.txt`` the image path, the predicted text and a
    confidence score (product of the per-step maximum probabilities).

    Args:
        opt: options object. The attributes read here are ``Prediction``,
            ``character``, ``rgb``, ``saved_model``, ``imgH``, ``imgW``,
            ``PAD``, ``image_folder``, ``batch_size``, ``workers`` and
            ``batch_max_length`` (plus those only printed). ``opt.num_class``
            is set, and ``opt.input_channel`` is overwritten when ``opt.rgb``
            is truthy.
    """
    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    # Per the original author's note: the CTC converter adds a blank token,
    # the attention converter adds the [GO] and [s] tokens.
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    log_path = './log_demo_result.txt'
    dashed_line = '-' * 80
    head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                             1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)

            # Append mode: the header and rows of every batch accumulate in the
            # same file; the context manager closes it even if a row fails.
            with open(log_path, 'a') as log:
                print(f'{dashed_line}\n{head}\n{dashed_line}')
                log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

                for img_name, pred, pred_max_prob in zip(image_path_list,
                                                         preds_str,
                                                         preds_max_prob):
                    if 'Attn' in opt.Prediction:
                        pred_EOS = pred.find('[s]')
                        # str.find returns -1 when no "end of sentence" token
                        # was produced; slicing with -1 would drop the last
                        # character, so only prune when the token exists.
                        if pred_EOS != -1:
                            pred = pred[:pred_EOS]
                            pred_max_prob = pred_max_prob[:pred_EOS]

                    # calculate confidence score (= product of pred_max_prob)
                    if pred_max_prob.numel() == 0:
                        # Empty prediction ([s] emitted first): nothing to
                        # multiply, and cumprod(...)[-1] would raise IndexError.
                        confidence_score = 0.0
                    else:
                        confidence_score = pred_max_prob.cumprod(dim=0)[-1]

                    row = f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}'
                    print(row)
                    log.write(row + '\n')
Пример #2
0
def _latest_checkpoint(pattern):
    """Return the path matching ``pattern`` with the highest iteration number, or None.

    ``glob.glob`` returns matches in arbitrary order, so the newest checkpoint
    is selected by the integer in the second-to-last ``_``-separated field of
    the file name (``iter_<N>_<kind>.pth``). Names that do not parse rank lowest.
    """
    candidates = glob.glob(pattern)
    if not candidates:
        return None

    def _iteration(path):
        try:
            return int(os.path.basename(path).split('_')[-2])
        except (ValueError, IndexError):
            return -1

    return max(candidates, key=_iteration)


def _prune_after_eos(text):
    """Return ``text`` cut at the first '[s]' token, or unchanged when there is none."""
    eos = text.find('[s]')
    # str.find returns -1 when the token is absent; slicing with -1 would
    # silently drop the last character.
    return text if eos == -1 else text[:eos]


def train(opt):
    """Train the synthesis generator, the OCR model and the image discriminator.

    Sets up the validation loader, the three models (``AdaINGenV4``, ``Model``
    and ``MsImageDisV1``), their optimizers and schedulers, optionally restores
    checkpoints, then loops: each iteration updates the OCR model (unless
    ``opt.ocrFixed``), the discriminator and the generator. Every
    ``opt.valInterval`` iterations (and at iteration 0) it saves sample images,
    runs ``validation_synth_lrw_res``, plots/logs the metrics and keeps the
    best checkpoints under ``opt.exp_dir/opt.exp_name``.

    Side effects on ``opt``: ``select_data`` and ``batch_ratio`` are split on
    '-', ``batch_size`` is doubled, ``num_class`` is set, ``input_channel`` is
    overwritten when ``opt.rgb`` is truthy, and ``saved_synth_model`` /
    ``saved_dis_model`` may be replaced by the latest checkpoint found when
    ``opt.modelFolderFlag`` is set.

    This function does not return: it calls ``sys.exit()`` once
    ``opt.num_iter`` iterations have been reached.

    Raises:
        ValueError: if ``opt.styleReconWeight`` is non-zero and
            ``opt.styleLoss`` is neither ``'l1'`` nor ``'triplet'``.
    """
    plotDir = os.path.join(opt.exp_dir,opt.exp_name,'plots')
    if not os.path.exists(plotDir):
        os.makedirs(plotDir)

    lib.print_model_settings(locals().copy())

    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')

    # Each training batch is later split in half: the first half feeds the
    # generator, the second half serves as real images for the discriminator.
    opt.batch_size = opt.batch_size*2

    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager guarantees the dataset log is closed even if building
    # the validation dataset raises.
    with open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=False,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    model = AdaINGenV4(opt)
    ocrModel = Model(opt)
    disModel = MsImageDisV1(opt)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for currModel in [model, ocrModel, disModel]:
        for name, param in currModel.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception:  # for batchnorm.
                # kaiming_normal_ rejects parameters it cannot handle (the
                # original note says batchnorm); fall back to filling with 1.
                if 'weight' in name:
                    param.data.fill_(1)
                continue

    # data parallel for multi-GPU
    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    if not opt.ocrFixed:
        ocrModel.train()
    else:
        # Fixed OCR: every sub-module except SequenceModeling is put in eval mode.
        ocrModel.module.Transformation.eval()
        ocrModel.module.FeatureExtraction.eval()
        ocrModel.module.AdaptiveAvgPool.eval()
        ocrModel.module.Prediction.eval()

    model = torch.nn.DataParallel(model).to(device)
    model.train()

    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    if opt.modelFolderFlag:
        # Resume from the experiment folder: pick the checkpoint with the
        # highest iteration number (glob order alone is arbitrary).
        latest_synth = _latest_checkpoint(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth"))
        if latest_synth is not None:
            opt.saved_synth_model = latest_synth

        latest_dis = _latest_checkpoint(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_dis.pth"))
        if latest_dis is not None:
            opt.saved_dis_model = latest_dis

    # loading pre-trained model
    if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None':
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        if opt.FT:
            ocrModel.load_state_dict(torch.load(opt.saved_ocr_model), strict=False)
        else:
            ocrModel.load_state_dict(torch.load(opt.saved_ocr_model))
    print("OCRModel:")
    print(ocrModel)

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_synth_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_synth_model))
    print("SynthModel:")
    print(model)

    if opt.saved_dis_model != '' and opt.saved_dis_model != 'None':
        print(f'loading pretrained discriminator model from {opt.saved_dis_model}')
        if opt.FT:
            disModel.load_state_dict(torch.load(opt.saved_dis_model), strict=False)
        else:
            disModel.load_state_dict(torch.load(opt.saved_dis_model))
    print("DisModel:")
    print(disModel)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    recCriterion = torch.nn.L1Loss()
    if opt.styleLoss == 'l1':
        styleRecCriterion = torch.nn.L1Loss()
    elif opt.styleLoss == 'triplet':
        styleRecCriterion = torch.nn.TripletMarginLoss(margin=opt.tripletMargin, p=1)
    elif opt.styleReconWeight != 0.0:
        # Without this check the loop would fail later with a NameError on the
        # first style-reconstruction loss computation.
        raise ValueError(f"unsupported styleLoss {opt.styleLoss!r}; expected 'l1' or 'triplet'")
    # for validation; check only positive pairs
    styleTestRecCriterion = torch.nn.L1Loss()

    # loss averager
    loss_avg_ocr = Averager()
    loss_avg = Averager()
    loss_avg_dis = Averager()

    loss_avg_ocrRecon_1 = Averager()
    loss_avg_ocrRecon_2 = Averager()
    loss_avg_gen = Averager()
    loss_avg_imgRecon = Averager()
    loss_avg_styRecon = Averager()

    ##---------------------------------------##
    # keep only the generator parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.optim=='adam':
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay)
    print("SynthOptimizer:")
    print(optimizer)

    # filter parameters for OCR training
    ocr_filtered_parameters = []
    ocr_params_num = []
    for p in filter(lambda p: p.requires_grad, ocrModel.parameters()):
        ocr_filtered_parameters.append(p)
        ocr_params_num.append(np.prod(p.size()))
    print('OCR Trainable params num : ', sum(ocr_params_num))

    # setup optimizer
    if opt.optim=='adam':
        ocr_optimizer = optim.Adam(ocr_filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)
    else:
        ocr_optimizer = optim.Adadelta(ocr_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay)
    print("OCROptimizer:")
    print(ocr_optimizer)

    # filter parameters for discriminator training
    dis_filtered_parameters = []
    dis_params_num = []
    for p in filter(lambda p: p.requires_grad, disModel.parameters()):
        dis_filtered_parameters.append(p)
        dis_params_num.append(np.prod(p.size()))
    print('Dis Trainable params num : ', sum(dis_params_num))

    # setup optimizer
    if opt.optim=='adam':
        dis_optimizer = optim.Adam(dis_filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)
    else:
        dis_optimizer = optim.Adadelta(dis_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay)
    print("DisOptimizer:")
    print(dis_optimizer)
    ##---------------------------------------##

    # ---- final options ----
    with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        # Checkpoints are named iter_<N>_synth.pth; recover N when the name
        # follows that pattern, otherwise start counting from 0.
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except (ValueError, IndexError):
            pass

    # get schedulers
    scheduler = get_scheduler(optimizer,opt)
    ocr_scheduler = get_scheduler(ocr_optimizer,opt)
    dis_scheduler = get_scheduler(dis_optimizer,opt)

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    best_accuracy_ocr = -1
    best_norm_ED_ocr = -1
    iteration = start_iter

    while(True):
        # train part

        # NOTE(review): PyTorch >= 1.1 expects optimizer.step() before
        # scheduler.step(); stepping the schedulers first skips the initial
        # learning-rate value. Left as in the original - confirm it is intended.
        if opt.lr_policy !="None":
            scheduler.step()
            ocr_scheduler.step()
            dis_scheduler.step()

        image_tensors_all, labels_1_all, labels_2_all = train_dataset.get_batch()

        # First half of the batch: generator inputs; second half: real images
        # for the discriminator.
        disCnt = int(image_tensors_all.size(0)/2)
        image_tensors, image_tensors_real, labels_gt, labels_2 = image_tensors_all[:disCnt], image_tensors_all[disCnt:disCnt+disCnt], labels_1_all[:disCnt], labels_2_all[:disCnt]

        image = image_tensors.to(device)
        image_real = image_tensors_real.to(device)
        batch_size = image.size(0)

        ##-----------------------------------##
        # generate text (labels) from ocr.forward
        if opt.ocrFixed:
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            # The predictions are only decoded to strings, so no graph is needed.
            with torch.no_grad():
                if 'CTC' in opt.Prediction:
                    preds = ocrModel(image, text_for_pred)
                    # (The original sliced preds with an undefined `text_for_loss`
                    # here, which raised NameError; the slice is dropped.)
                    preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                    _, preds_index = preds.max(2)
                    labels_1 = converter.decode(preds_index.data, preds_size.data)
                else:
                    preds = ocrModel(image, text_for_pred, is_train=False)
                    _, preds_index = preds.max(2)
                    labels_1 = converter.decode(preds_index, length_for_pred)
                    for idx, pred in enumerate(labels_1):
                        labels_1[idx] = _prune_after_eos(pred)  # prune after "end of sentence" token ([s])
        else:
            labels_1 = labels_gt

        ##-----------------------------------##

        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

        # forward pass from style and word generator
        images_recon_1, images_recon_2, style = model(image, text_1, text_2)

        if 'CTC' in opt.Prediction:

            if not opt.ocrFixed:
                # ocr training with orig image
                preds_ocr = ocrModel(image, text_1)
                preds_size_ocr = torch.IntTensor([preds_ocr.size(1)] * batch_size)
                preds_ocr = preds_ocr.log_softmax(2).permute(1, 0, 2)

                ocrCost_train = ocrCriterion(preds_ocr, text_1, preds_size_ocr, length_1)

            # content loss for reconstructed images
            preds_1 = ocrModel(images_recon_1, text_1)
            preds_size_1 = torch.IntTensor([preds_1.size(1)] * batch_size)
            preds_1 = preds_1.log_softmax(2).permute(1, 0, 2)

            preds_2 = ocrModel(images_recon_2, text_2)
            preds_size_2 = torch.IntTensor([preds_2.size(1)] * batch_size)
            preds_2 = preds_2.log_softmax(2).permute(1, 0, 2)
            ocrCost_1 = ocrCriterion(preds_1, text_1, preds_size_1, length_1)
            ocrCost_2 = ocrCriterion(preds_2, text_2, preds_size_2, length_2)

        else:
            if not opt.ocrFixed:
                # ocr training with orig image
                preds_ocr = ocrModel(image, text_1[:, :-1])  # align with Attention.forward
                target_ocr = text_1[:, 1:]  # without [GO] Symbol

                ocrCost_train = ocrCriterion(preds_ocr.view(-1, preds_ocr.shape[-1]), target_ocr.contiguous().view(-1))

            # content loss for reconstructed images
            preds_1 = ocrModel(images_recon_1, text_1[:, :-1], is_train=False)  # align with Attention.forward
            target_1 = text_1[:, 1:]  # without [GO] Symbol

            preds_2 = ocrModel(images_recon_2, text_2[:, :-1], is_train=False)  # align with Attention.forward
            target_2 = text_2[:, 1:]  # without [GO] Symbol

            ocrCost_1 = ocrCriterion(preds_1.view(-1, preds_1.shape[-1]), target_1.contiguous().view(-1))
            ocrCost_2 = ocrCriterion(preds_2.view(-1, preds_2.shape[-1]), target_2.contiguous().view(-1))

        if not opt.ocrFixed:
            # training OCR
            # NOTE(review): this optimizer step modifies ocrModel's weights
            # before cost.backward() below backpropagates through ocrModel
            # outputs computed above - confirm this ordering is safe on the
            # PyTorch version in use.
            ocrModel.zero_grad()
            ocrCost_train.backward()
            # torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
            ocr_optimizer.step()
            loss_avg_ocr.add(ocrCost_train)
        else:
            # if ocr is fixed; ignore this loss
            loss_avg_ocr.add(torch.tensor(0.0))

        # Domain discriminator: Dis update (generator outputs are detached)
        disModel.zero_grad()
        disCost = opt.disWeight*0.5*(disModel.module.calc_dis_loss(images_recon_1.detach(), image_real) + disModel.module.calc_dis_loss(images_recon_2.detach(), image))
        disCost.backward()
        # torch.nn.utils.clip_grad_norm_(disModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        dis_optimizer.step()
        loss_avg_dis.add(disCost)

        # [Style Encoder] + [Word Generator] update
        # Adversarial loss
        disGenCost = 0.5*(disModel.module.calc_gen_loss(images_recon_1)+disModel.module.calc_gen_loss(images_recon_2))

        # Input reconstruction loss
        recCost = recCriterion(images_recon_1,image)

        # Pair style reconstruction loss
        if opt.styleReconWeight == 0.0:
            styleRecCost = torch.tensor(0.0)
        else:
            if opt.styleLoss == 'l1':
                if opt.styleDetach:
                    styleRecCost = styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style.detach())
                else:
                    styleRecCost = styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style)
            elif opt.styleLoss=='triplet':
                if opt.styleDetach:
                    styleRecCost = styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style.detach(), model(image_real, None, None, styleFlag=True))
                else:
                    styleRecCost = styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style, model(image_real, None, None, styleFlag=True))

        # OCR Content cost
        ocrCost = 0.5*(ocrCost_1+ocrCost_2)

        cost = opt.ocrWeight*ocrCost + opt.reconWeight*recCost + opt.disWeight*disGenCost + opt.styleReconWeight*styleRecCost

        # Only the generator optimizer steps here; the OCR and discriminator
        # gradients are cleared so this backward pass starts from zero for them.
        model.zero_grad()
        ocrModel.zero_grad()
        disModel.zero_grad()
        cost.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        loss_avg.add(cost)

        # Individual losses
        loss_avg_ocrRecon_1.add(opt.ocrWeight*0.5*ocrCost_1)
        loss_avg_ocrRecon_2.add(opt.ocrWeight*0.5*ocrCost_2)
        loss_avg_gen.add(opt.disWeight*disGenCost)
        loss_avg_imgRecon.add(opt.reconWeight*recCost)
        loss_avg_styRecon.add(opt.styleReconWeight*styleRecCost)

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0'

            # Save training images
            os.makedirs(os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration)), exist_ok=True)
            for trImgCntr in range(batch_size):
                # Best effort: a failure to save one sample must not stop
                # training, but KeyboardInterrupt/SystemExit still propagate.
                try:
                    save_image(tensor2im(image[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_input_'+labels_gt[trImgCntr]+'.png'))
                    save_image(tensor2im(images_recon_1[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_recon_'+labels_1[trImgCntr]+'.png'))
                    save_image(tensor2im(images_recon_2[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_pair_'+labels_2[trImgCntr]+'.png'))
                except Exception:
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time
            # for log

            with open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log:
                model.eval()
                ocrModel.module.Transformation.eval()
                ocrModel.module.FeatureExtraction.eval()
                ocrModel.module.AdaptiveAvgPool.eval()
                ocrModel.module.SequenceModeling.eval()
                ocrModel.module.Prediction.eval()
                disModel.eval()

                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation_synth_lrw_res(
                        iteration, model, ocrModel, disModel, recCriterion, styleTestRecCriterion, ocrCriterion, valid_loader, converter, opt)
                model.train()
                if not opt.ocrFixed:
                    ocrModel.train()
                else:
                    # Fixed OCR: restore the pre-validation state, in which only
                    # SequenceModeling was in train mode.
                    ocrModel.module.SequenceModeling.train()

                disModel.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train OCR loss: {loss_avg_ocr.val():0.5f}, Train Synth loss: {loss_avg.val():0.5f}, Train Dis loss: {loss_avg_dis.val():0.5f}, Valid OCR loss: {valid_loss[0]:0.5f}, Valid Synth loss: {valid_loss[1]:0.5f}, Valid Dis loss: {valid_loss[2]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'

                current_model_log_ocr = f'{"Current_accuracy_OCR":17s}: {current_accuracy[0]:0.3f}, {"Current_norm_ED_OCR":17s}: {current_norm_ED[0]:0.2f}'
                current_model_log_1 = f'{"Current_accuracy_recon":17s}: {current_accuracy[1]:0.3f}, {"Current_norm_ED_recon":17s}: {current_norm_ED[1]:0.2f}'
                current_model_log_2 = f'{"Current_accuracy_pair":17s}: {current_accuracy[2]:0.3f}, {"Current_norm_ED_pair":17s}: {current_norm_ED[2]:0.2f}'

                # plotting
                lib.plot.plot(os.path.join(plotDir,'Train-OCR-Loss'), loss_avg_ocr.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-Synth-Loss'), loss_avg.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item())

                lib.plot.plot(os.path.join(plotDir,'Train-OCR-Recon1-Loss'), loss_avg_ocrRecon_1.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-OCR-Recon2-Loss'), loss_avg_ocrRecon_2.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-ImgRecon1-Loss'), loss_avg_imgRecon.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-StyRecon2-Loss'), loss_avg_styRecon.val().item())

                lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Loss'), valid_loss[0].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-Synth-Loss'), valid_loss[1].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-Dis-Loss'), valid_loss[2].item())

                lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Recon1-Loss'), valid_loss[3].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Recon2-Loss'), valid_loss[4].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-Gen-Loss'), valid_loss[5].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-ImgRecon1-Loss'), valid_loss[6].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-StyRecon2-Loss'), valid_loss[7].item())

                lib.plot.plot(os.path.join(plotDir,'Orig-OCR-WordAccuracy'), current_accuracy[0])
                lib.plot.plot(os.path.join(plotDir,'Recon-OCR-WordAccuracy'), current_accuracy[1])
                lib.plot.plot(os.path.join(plotDir,'Pair-OCR-WordAccuracy'), current_accuracy[2])

                lib.plot.plot(os.path.join(plotDir,'Orig-OCR-CharAccuracy'), current_norm_ED[0])
                lib.plot.plot(os.path.join(plotDir,'Recon-OCR-CharAccuracy'), current_norm_ED[1])
                lib.plot.plot(os.path.join(plotDir,'Pair-OCR-CharAccuracy'), current_norm_ED[2])

                # keep best generator/discriminator (judged on the reconstruction metrics, index 1)
                if current_accuracy[1] > best_accuracy:
                    best_accuracy = current_accuracy[1]
                    torch.save(model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy.pth'))
                    torch.save(disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy_dis.pth'))
                if current_norm_ED[1] > best_norm_ED:
                    best_norm_ED = current_norm_ED[1]
                    torch.save(model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED.pth'))
                    torch.save(disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED_dis.pth'))
                best_model_log = f'{"Best_accuracy_Recon":17s}: {best_accuracy:0.3f}, {"Best_norm_ED_Recon":17s}: {best_norm_ED:0.2f}'

                # keep best OCR model (index 0 metrics); only saved when the OCR is being trained
                if current_accuracy[0] > best_accuracy_ocr:
                    best_accuracy_ocr = current_accuracy[0]
                    if not opt.ocrFixed:
                        torch.save(ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy_ocr.pth'))
                if current_norm_ED[0] > best_norm_ED_ocr:
                    best_norm_ED_ocr = current_norm_ED[0]
                    if not opt.ocrFixed:
                        torch.save(ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED_ocr.pth'))
                best_model_log_ocr = f'{"Best_accuracy_ocr":17s}: {best_accuracy_ocr:0.3f}, {"Best_norm_ED_ocr":17s}: {best_norm_ED_ocr:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log_ocr}\n{current_model_log_1}\n{current_model_log_2}\n{best_model_log_ocr}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results (first five samples of each kind)
                dashed_line = '-' * 80
                head = f'{"Ground Truth":32s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                for gt_ocr, pred_ocr, confidence_ocr, gt_1, pred_1, confidence_1, gt_2, pred_2, confidence_2 in zip(labels[0][:5], preds[0][:5], confidence_score[0][:5], labels[1][:5], preds[1][:5], confidence_score[1][:5], labels[2][:5], preds[2][:5], confidence_score[2][:5]):
                    if 'Attn' in opt.Prediction:
                        # Cut at the end-of-sentence token only when it is
                        # present, so token-free predictions keep their last character.
                        pred_ocr = _prune_after_eos(pred_ocr)
                        pred_1 = _prune_after_eos(pred_1)
                        pred_2 = _prune_after_eos(pred_2)

                    predicted_result_log += f'{"ocr"}: {gt_ocr:27s} | {pred_ocr:25s} | {confidence_ocr:0.4f}\t{str(pred_ocr == gt_ocr)}\n'
                    predicted_result_log += f'{"recon"}: {gt_1:25s} | {pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n'
                    predicted_result_log += f'{"pair"}: {gt_2:26s} | {pred_2:25s} | {confidence_2:0.4f}\t{str(pred_2 == gt_2)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

                loss_avg_ocr.reset()
                loss_avg.reset()
                loss_avg_dis.reset()

                loss_avg_ocrRecon_1.reset()
                loss_avg_ocrRecon_2.reset()
                loss_avg_gen.reset()
                loss_avg_imgRecon.reset()
                loss_avg_styRecon.reset()

            lib.plot.flush()

        lib.plot.tick()

        # save model per 1e+5 iter.
        if (iteration) % 1e+5 == 0:
            torch.save(
                model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth'))
            if not opt.ocrFixed:
                torch.save(
                    ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_ocr.pth'))
            torch.save(
                disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_dis.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
Пример #3
0
def demo(opt):
    """Recognize the text in every image of ``opt.image_folder`` and export a CSV.

    Builds the model described by ``opt``, loads the weights stored at
    ``opt.saved_model``, greedily decodes every batch and prints one
    ``name<TAB>label`` line per image.  Once all batches are processed the
    collected names and labels are written to ``./f_submission.csv`` (columns
    ``name`` and ``label``).

    Args:
        opt: option namespace.  Among others this function reads
            ``Prediction``, ``character``, ``rgb``, ``saved_model``, ``imgH``,
            ``imgW``, ``PAD``, ``image_folder``, ``batch_size``, ``workers``
            and ``batch_max_length``; it sets ``num_class`` and, when
            ``opt.rgb`` is true, ``input_channel``.
    """
    # Resolve the device once so the function also works on CPU-only machines
    # (the model was only moved to the GPU when one was available, so the
    # inputs must follow the same rule).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    model = torch.nn.DataParallel(model).to(device)

    # load model; map_location lets a GPU-saved checkpoint load on CPU as well
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    img = []
    pred_ls = []
    # The whole inference loop runs without autograd bookkeeping; previously
    # only the input tensors were created under no_grad and the forward pass
    # still tracked gradients.
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred).log_softmax(2)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            print('-' * 80)
            print('image_path\tpredicted_labels')
            print('-' * 80)
            for img_name, pred in zip(image_path_list, preds_str):
                if 'Attn' in opt.Prediction:
                    # Prune after the "end of sentence" token ([s]).  partition()
                    # keeps the full string when the token is missing, whereas
                    # slicing with find() == -1 silently dropped the last char.
                    pred = pred.partition('[s]')[0]
                img_name = img_name.split('/')[-1]
                pred = pred.strip()
                img.append(img_name)
                pred_ls.append(pred)
                print(f'{img_name}\t{pred}')
    print(img)
    print(pred_ls)
    submission = pd.DataFrame({'name': img, 'label': pred_ls})
    submission.to_csv('./f_submission.csv', index=False)
Пример #4
0
def demo(opt):
    """Run text recognition on ``opt.image_folder`` and return the predictions.

    Builds the model described by ``opt``, loads the weights stored at
    ``opt.saved_model`` and greedily decodes every batch produced by the
    data loader.

    Args:
        opt: option namespace.  Among others this function reads
            ``Prediction``, ``character``, ``rgb``, ``saved_model``, ``imgH``,
            ``imgW``, ``PAD``, ``image_folder``, ``batch_size``, ``workers``
            and ``batch_max_length``; it sets ``num_class`` and, when
            ``opt.rgb`` is true, ``input_channel``.

    Returns:
        list: the decoded predictions of *all* batches, in loader order
        (``shuffle=False``).  An empty list is returned when the loader
        yields no batch.
    """
    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    model = torch.nn.DataParallel(model).to(device)

    # load model; map_location keeps the checkpoint loadable when it was saved
    # on a different device than the one used now
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    # predict
    model.eval()
    # Accumulator lives outside the batch loop: it used to be re-created for
    # every batch, so only the last batch was returned and an empty loader
    # raised NameError at the return statement.
    all_predictions = []
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                             1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred).log_softmax(2)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            for _img_name, pred in zip(image_path_list, preds_str):
                if 'Attn' in opt.Prediction:
                    # Prune after the "end of sentence" token ([s]).  partition()
                    # keeps the full string when the token is missing, whereas
                    # slicing with find() == -1 silently dropped the last char.
                    pred = pred.partition('[s]')[0]
                all_predictions.append(pred)
    return all_predictions
Пример #5
0
def demo(opt):
    """Recognize the text in ``opt.image_folder`` and dump the results as JSON.

    Builds the model described by ``opt`` (CTC, Bert or attention prediction
    head), loads the weights stored at ``opt.saved_model`` and greedily
    decodes every batch.  The predictions are stored in
    ``./result/<folder>/result.json`` where ``<folder>`` is the second to
    last ``/``-separated component of ``opt.image_folder``.  Each JSON key is
    the image file name without directory and extension, with ``gt`` replaced
    by ``res``; each value is ``[{"transcription": <prediction>}]``.

    Args:
        opt: option namespace.  Among others this function reads
            ``Prediction``, ``character``, ``rgb``, ``saved_model``, ``imgH``,
            ``imgW``, ``PAD``, ``image_folder``, ``batch_size``, ``workers``
            and ``batch_max_length``; it sets ``num_class``,
            ``alphabet_size`` and, when ``opt.rgb`` is true,
            ``input_channel``.
    """
    # Resolve the device once so the function also works on CPU-only machines
    # (the model was only moved to the GPU when one was available, so the
    # inputs must follow the same rule).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    elif 'Bert' in opt.Prediction:
        converter = TransformerConverter(opt.character, opt.batch_max_length)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    opt.alphabet_size = len(opt.character) + 2  # +2 for [UNK]+[EOS]

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    model = torch.nn.DataParallel(model).to(device)

    # load model; map_location lets a GPU-saved checkpoint load on CPU as well
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # mkdir result (exist_ok avoids the check-then-create race)
    experiment_name = os.path.join('./result', opt.image_folder.split('/')[-2])
    os.makedirs(experiment_name, exist_ok=True)
    result = {}

    # predict
    model.eval()
    # The whole inference loop runs without autograd bookkeeping; previously
    # the CTC and attention forward passes still tracked gradients.
    with torch.no_grad():
        for idx, (image_tensors, image_path_list) in enumerate(demo_loader):
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred).log_softmax(2)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)

            elif 'Bert' in opt.Prediction:
                pad_mask = None
                preds = model(image, pad_mask)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds[1].max(2)
                length_for_pred = torch.IntTensor([preds_index.size(-1)] * batch_size).to(device)
                preds_str = converter.decode(preds_index, length_for_pred)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            # progress: current batch index / (possibly fractional) batch count
            print(f'{idx}/{len(demo_data) / opt.batch_size}')

            for img_name, pred in zip(image_path_list, preds_str):
                if 'Attn' in opt.Prediction:
                    # Prune after the "end of sentence" token ([s]).  partition()
                    # keeps the full string when the token is missing, whereas
                    # slicing with find() == -1 silently dropped the last char.
                    pred = pred.partition('[s]')[0]

                # write in json
                name = f'{img_name}'.split('/')[-1].replace('gt', 'res').split('.')[0]
                value = [{"transcription": f'{pred}'}]
                result[name] = value

    with open(f'{experiment_name}/result.json', 'w') as f:
        json.dump(result, f)
        print("writed finish...")
Пример #6
0
def demoToTxt2(image_folder, saved_model, txtFile):  # sensitive
    """Recognize the text in ``image_folder`` and write the results to ``txtFile``.

    The remaining options (batch size, image size, model architecture, ...)
    come from an ``argparse`` parser whose defaults are defined below;
    ``image_folder`` and ``saved_model`` only provide the *defaults* of the
    ``--image_folder`` / ``--saved_model`` options.

    NOTE(review): ``parser.parse_args()`` reads ``sys.argv``, so command-line
    arguments of the running process override the values passed to this
    function and unknown arguments abort the program.  Kept as is for
    backward compatibility.

    Args:
        image_folder: default path of the folder that contains the text images.
        saved_model: default path of the checkpoint to load.
        txtFile: path of the output file.  It is overwritten and receives one
            ``image_path<TAB>prediction`` line per image.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_folder',
                        default=image_folder,
                        help='path to image_folder which contains text images')
    parser.add_argument('--workers',
                        type=int,
                        help='number of data loading workers',
                        default=4)
    parser.add_argument('--batch_size',
                        type=int,
                        default=100,
                        help='input batch size')
    parser.add_argument('--saved_model',
                        default=saved_model,
                        help="path to saved_model to evaluation")
    # Data processing
    parser.add_argument('--batch_max_length',
                        type=int,
                        default=20,
                        help='maximum-label-length')
    parser.add_argument('--imgH',
                        type=int,
                        default=32,
                        help='the height of the input image')
    parser.add_argument('--imgW',
                        type=int,
                        default=100,
                        help='the width of the input image')
    parser.add_argument('--rgb', action='store_true', help='use rgb input')
    parser.add_argument('--character',
                        type=str,
                        default='0123456789',
                        help='character label')
    parser.add_argument('--sensitive',
                        default=True,
                        help='for sensitive character mode')
    parser.add_argument('--PAD',
                        default=False,
                        action='store_true',
                        help='whether to keep ratio then pad for image resize')
    # Model Architecture
    parser.add_argument('--Transformation',
                        default='TPS',
                        type=str,
                        help='Transformation stage. None|TPS')
    parser.add_argument('--FeatureExtraction',
                        default='ResNet',
                        type=str,
                        help='FeatureExtraction stage. VGG|RCNN|ResNet')
    parser.add_argument('--SequenceModeling',
                        default='BiLSTM',
                        type=str,
                        help='SequenceModeling stage. None|BiLSTM')
    parser.add_argument('--Prediction',
                        default='Attn',
                        type=str,
                        help='Prediction stage. CTC|Attn')
    parser.add_argument('--num_fiducial',
                        type=int,
                        default=20,
                        help='number of fiducial points of TPS-STN')
    parser.add_argument(
        '--input_channel',
        type=int,
        default=1,
        help='the number of input channel of Feature extractor')
    parser.add_argument(
        '--output_channel',
        type=int,
        default=512,
        help='the number of output channel of Feature extractor')
    parser.add_argument('--hidden_size',
                        type=int,
                        default=256,
                        help='the size of the LSTM hidden state')

    opt = parser.parse_args()
    # vocab / character number configuration
    if opt.sensitive:
        opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'

    cudnn.benchmark = True
    cudnn.deterministic = True
    opt.num_gpu = torch.cuda.device_count()

    # Resolve the device once so the function also works on CPU-only machines
    # (the model was only moved to the GPU when one was available, so the
    # inputs must follow the same rule).
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    model = torch.nn.DataParallel(model).to(device)

    # load model; map_location lets a GPU-saved checkpoint load on CPU as well
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    # predict
    model.eval()
    # The context manager guarantees the output file is flushed and closed
    # (it used to be left open), and no_grad now covers the forward pass too.
    with open(txtFile, 'w') as saved_file, torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(
                batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred).log_softmax(2)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            print('-' * 80)
            print('image_path\tpredicted_labels')
            print('-' * 80)

            for img_name, pred in zip(image_path_list, preds_str):
                if 'Attn' in opt.Prediction:
                    # Prune after the "end of sentence" token ([s]).  partition()
                    # keeps the full string when the token is missing, whereas
                    # slicing with find() == -1 silently dropped the last char.
                    pred = pred.partition('[s]')[0]
                print(f'{img_name}\t{pred}')
                saved_file.write(f'{img_name}\t{pred}\n')
Пример #7
0
def demo(opt):
    """Evaluate the model on an LMDB dataset and sort the images by outcome.

    Builds the model described by ``opt``, loads the weights stored at
    ``opt.saved_model`` and greedily decodes every batch of the ``'Val'``
    ``LmdbDataset`` rooted at ``opt.image_folder``.  Ground truth and
    prediction are compared case-insensitively on their alphanumeric
    characters only.  The original image of every sample is saved as
    ``<key>_<prediction>_<ground truth>.jpg`` under ``result/fail`` or
    ``result/success``; the overall accuracy is printed at the end.

    Args:
        opt: option namespace.  Among others this function reads
            ``Prediction``, ``character``, ``rgb``, ``saved_model``, ``imgH``,
            ``imgW``, ``PAD``, ``image_folder``, ``batch_size``, ``workers``
            and ``batch_max_length``; it sets ``num_class`` and, when
            ``opt.rgb`` is true, ``input_channel``.
    """
    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    print(opt.num_class)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = LmdbDataset(root=opt.image_folder, opt=opt,
                            mode='Val')  # LMDB-backed validation data

    # NOTE: drop_last=True means a dataset smaller than one batch yields no
    # batch at all, hence the zero-sample guard at the end of the function.
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True,
                                              drop_last=True)

    # Make sure both output folders exist; saving into a missing folder used
    # to abort the run at the first sample.
    for outcome_dir in ('fail', 'success'):
        os.makedirs(os.path.join('result', outcome_dir), exist_ok=True)

    # predict
    model.eval()
    fail_count, sample_count = 0, 0
    # The log file is still opened in append mode (which creates it when it
    # is missing) to keep the previous on-disk side effect, but nothing is
    # written to it; the context manager closes it on every exit path.
    with open('./log_demo_result.txt', 'a'), torch.no_grad():
        for image_tensors, image_path_list, original_images, indexes in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                             1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, preds_size)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            for gt, pred, original_image, lmdb_key in zip(
                    image_path_list, preds_str, original_images, indexes):
                if 'Attn' in opt.Prediction:
                    # Prune after the "end of sentence" token ([s]).  partition()
                    # keeps the full string when the token is missing, whereas
                    # slicing with find() == -1 silently dropped the last char.
                    pred = pred.partition('[s]')[0]

                # Compare on upper-cased alphanumeric characters only.
                compare_gt = "".join(x.upper() for x in gt if x.isalnum())
                compare_pred = "".join(x.upper() for x in pred if x.isalnum())

                if compare_gt != compare_pred:
                    fail_count += 1
                    outcome = 'fail'
                else:
                    outcome = 'success'

                try:
                    original_image.save(
                        os.path.join(
                            'result', outcome,
                            f'{lmdb_key}_{compare_pred}_{compare_gt}.jpg'))
                except Exception as e:
                    print(
                        f'Error: {e} {lmdb_key}_{compare_pred}_{compare_gt}'
                    )
                    exit(1)

                sample_count += 1

    # Guard against an empty run (no batch produced) instead of raising
    # ZeroDivisionError.
    accuracy = (sample_count - fail_count) / sample_count if sample_count else 0.0
    print(
        f'total accuracy: {accuracy:.2f} number of sample: {sample_count}'
    )
Пример #8
0
def demo(opt):
    """Recognize cropped word images and write ``x,y,...,text`` rows to a CSV.

    The bounding boxes are obtained from ``parse_csv(opt.input_image,
    opt.boxescsv)``; the N-th recognized image (in loader order, across all
    batches) is paired with ``bboxes[N]``.  For every image one line is
    written to ``<opt.output_folder>result.csv`` consisting of the box
    points as ``x,y,`` pairs followed by the predicted text.  After all
    batches are processed the CSV is copied to ``/input/output`` and that
    folder is archived as ``per_word_visual.zip``.

    Args:
        opt: option namespace.  Among others this function reads
            ``input_image``, ``boxescsv``, ``output_folder``, ``Prediction``,
            ``character``, ``rgb``, ``saved_model``, ``imgH``, ``imgW``,
            ``PAD``, ``image_folder``, ``batch_size``, ``workers`` and
            ``batch_max_length``; it sets ``num_class`` and, when ``opt.rgb``
            is true, ``input_channel``.
    """
    inputimage = opt.input_image
    boxesscv = opt.boxescsv
    bboxes = parse_csv(inputimage, boxesscv)

    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    result_path = f'{opt.output_folder}result.csv'

    # predict
    model.eval()
    # The CSV is opened once for the whole run.  It used to be re-opened in
    # 'w' mode for every batch, so each batch overwrote the previous one.
    with open(result_path, 'w') as log, torch.no_grad():
        # Running index of the first image of the current batch, so that
        # images of later batches are paired with their own bounding box
        # instead of restarting at bboxes[0].
        first_index = 0
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                             1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, preds_size)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            for img_index, pred in enumerate(preds_str, start=first_index):
                if 'Attn' in opt.Prediction:
                    # Prune after the "end of sentence" token ([s]).  partition()
                    # keeps the full string when the token is missing, whereas
                    # slicing with find() == -1 silently dropped the last char.
                    pred = pred.partition('[s]')[0]

                for pts in bboxes[img_index]:
                    x, y = pts
                    log.write(f'{x},{y},')
                log.write(f'{pred}\n')

            first_index += batch_size

    # copy log to local output folder (shutil instead of a shell `cp` string,
    # which broke on paths containing spaces or shell metacharacters)
    shutil.copy(result_path, '/input/output')
    shutil.make_archive('per_word_visual', 'zip', '/input/output')
Пример #9
0
def _textRecognition(opt):
    """Recognise text in the images under ``opt.image_folder`` and log it to CSV.

    Builds the recognition model described by ``opt``, loads the weights from
    ``opt.saved_model``, runs greedy decoding over every batch produced from
    ``opt.image_folder`` and writes one ``bookID,prediction`` row per accepted
    image to ``<opt.output_dirpath>/text_information.csv`` (the file is
    truncated on every call).

    Per image:
      * an empty prediction, or a confidence score below 0.5, causes
        ``deleteImageAndText(opt.book_img_dirpath, img_name)`` to be called
        and no row to be written;
      * otherwise only the first predicted character is kept, and it is
        replaced by ``"Undefined"`` when it is not an ASCII letter or digit.

    Args:
        opt: Options object. Attributes read here: ``Prediction``,
            ``character``, ``rgb``, ``imgH``, ``imgW``, ``num_fiducial``,
            ``input_channel``, ``output_channel``, ``hidden_size``,
            ``batch_max_length``, ``Transformation``, ``FeatureExtraction``,
            ``SequenceModeling``, ``saved_model``, ``PAD``, ``image_folder``,
            ``batch_size``, ``workers``, ``output_dirpath`` and
            ``book_img_dirpath``. ``opt.num_class`` is overwritten, and
            ``opt.input_channel`` is set to 3 when ``opt.rgb`` is truthy.

    Returns:
        None. Results are printed and written to the CSV file.
    """
    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # --- load model ---
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # --- prepare data ---
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    # --- predict ---
    # Characters accepted as a valid single-character prediction.
    char_list = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

    csv_filename = os.path.join(opt.output_dirpath, "text_information.csv")
    fieldnames = ["bookID", "prediction"]

    # Console table header; identical for every batch, so built once.
    dashed_line = '-' * 80
    head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

    # newline='' lets the csv module control line endings itself; without it
    # an extra blank line appears after every row on Windows.
    with open(csv_filename, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()

        model.eval()
        with torch.no_grad():
            for image_tensors, image_path_list in demo_loader:
                batch_size = image_tensors.size(0)
                image = image_tensors.to(device)
                # For max length prediction
                length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                                  batch_size).to(device)
                text_for_pred = torch.LongTensor(
                    batch_size, opt.batch_max_length + 1).fill_(0).to(device)

                if 'CTC' in opt.Prediction:
                    preds = model(image, text_for_pred)

                    # Select max probability (greedy decoding), then decode
                    # indices to characters.
                    preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                    _, preds_index = preds.max(2)
                    preds_index = preds_index.view(-1)
                    preds_str = converter.decode(preds_index.data,
                                                 preds_size.data)

                else:
                    preds = model(image, text_for_pred, is_train=False)

                    # Select max probability (greedy decoding), then decode
                    # indices to characters.
                    _, preds_index = preds.max(2)
                    preds_str = converter.decode(preds_index, length_for_pred)

                print(f'{dashed_line}\n{head}\n{dashed_line}')

                preds_prob = F.softmax(preds, dim=2)
                preds_max_prob, _ = preds_prob.max(dim=2)
                for img_name, pred, pred_max_prob in zip(
                        image_path_list, preds_str, preds_max_prob):
                    if 'Attn' in opt.Prediction:
                        eos_pos = pred.find('[s]')
                        # str.find returns -1 when there is no "[s]" token;
                        # slicing with -1 would silently drop the last
                        # character, so prune only when the token exists.
                        if eos_pos != -1:
                            pred = pred[:eos_pos]
                            pred_max_prob = pred_max_prob[:eos_pos]

                    # Nothing was recognised: there is neither a first
                    # character nor a confidence score to compute, so treat
                    # the image as unreadable.
                    if not pred or pred_max_prob.numel() == 0:
                        deleteImageAndText(opt.book_img_dirpath, img_name)
                        continue

                    # Confidence score = product of the per-step max
                    # probabilities.
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1]

                    # Low-confidence predictions are treated as unreadable.
                    if confidence_score < 0.5:
                        deleteImageAndText(opt.book_img_dirpath, img_name)
                        continue

                    # Only the first predicted character is reported.
                    pred = pred[0]
                    if pred not in char_list:
                        pred = "Undefined"

                    # extract the name part of the image
                    filename = os.path.basename(img_name)

                    print(
                        f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')

                    writer.writerow({"bookID": filename, "prediction": pred})
 print("Starting text classification")
 model.eval()
 with torch.no_grad():
     for image_tensors, image_path_list in demo_loader:
         batch_size = image_tensors.size(0)
         image = image_tensors.to(device)
         #image = (torch.from_numpy(crop_image).unsqueeze(0)).to(device)
         #print(image_path_list)
         #print(image.size())
         length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                           batch_size).to(device)
         text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                          1).fill_(0).to(device)
         preds = model(image, text_for_pred, is_train=False)
         _, preds_index = preds.max(2)
         preds_str = converter.decode(preds_index, length_for_pred)
         #print(preds_str)
         output_string = ""
         curr_order = 1
         for path, p in zip(image_path_list, preds_str):
             if 'Attn' in opt.Prediction:
                 pred_EOS = p.find('[s]')
                 p = p[:
                       pred_EOS]  # prune after "end of sentence" token ([s])
             path_info = path[len(result_folder
                                  ):-len(file_extension)].split("_")
             if (path_info[0] == filename):
                 if (int(path_info[1]) > curr_order):
                     curr_order += 1
                     output_string += "\n"
                 output_string += p + " "