def recognition(image):

    converter = AttnLabelConverter(args.character)
    args.num_class = len(converter.character)
    if args.rgb:
        args.input_channel = 3
    model = Model(args)

    model = torch.nn.DataParallel(model).to(device)
    model.load_state_dict(torch.load(args.model_dir, map_location=device))

    transformer = resizeNormalize((100, 32))

    # Convert RGB images to grayscale, which is necessary for our convolution layers.
    if image.mode == 'RGB':
        image = image.convert('L')

    image = transformer(image)
    batch_size = image.size(0)
    if torch.cuda.is_available():
        image = image.cuda()

    image = image.view(1, *image.size())
    image = Variable(image)

    model.eval()

    length_for_pred = torch.IntTensor([args.batch_max_length] *
                                      batch_size).to(device)
    text_for_pred = torch.LongTensor(batch_size, args.batch_max_length +
                                     1).fill_(0).to(device)
    preds = model(image, text_for_pred, is_train=False)

    _, preds_index = preds.max(2)

    preds_str = converter.decode(preds_index, length_for_pred)
    preds_prob = F.softmax(preds, dim=2)

    text_prediction = preds_str[0].replace("[s]", "")

    return text_prediction
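# A minimal sketch of calling recognition() above. It is not a standalone script: it assumes the
# global `args` (character set, model_dir, rgb flag), `device`, and the imported model and
# converter classes are already configured elsewhere in this file; the image path is hypothetical.
if __name__ == '__main__':
    from PIL import Image
    word_crop = Image.open('word_crop.png')  # any cropped word image (RGB or grayscale)
    print(recognition(word_crop))            # prints the decoded text with the '[s]' token stripped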
Example 2
class PytorchNet:
    def __init__(self, model_path, gpu_id=None):
        """
        Initialize the model
        :param model_path: path to the model file
        :param gpu_id: index of the GPU to run on
        """
        checkpoint = torch.load(model_path)
        print(f"load {checkpoint['epoch']} epoch params")
        config = checkpoint['config']
        alphabet = config['dataset']['alphabet']
        if gpu_id is not None and isinstance(
                gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % gpu_id)
        else:
            self.device = torch.device("cpu")
        print('device:', self.device)

        self.transform = []
        for t in config['dataset']['train']['dataset']['args']['transforms']:
            if t['type'] in ['ToTensor', 'Normalize']:
                self.transform.append(t)
        self.transform = get_transforms(self.transform)

        self.gpu_id = gpu_id
        img_h, img_w = 32, 100
        for process in config['dataset']['train']['dataset']['args'][
                'pre_processes']:
            if process['type'] == "Resize":
                img_h = process['args']['img_h']
                img_w = process['args']['img_w']
                break
        self.img_w = img_w
        self.img_h = img_h
        self.img_mode = config['dataset']['train']['dataset']['args'][
            'img_mode']
        self.alphabet = alphabet
        img_channel = 3 if config['dataset']['train']['dataset']['args'][
            'img_mode'] != 'GRAY' else 1

        if config['arch']['args']['prediction']['type'] == 'CTC':
            self.converter = CTCLabelConverter(config['dataset']['alphabet'])
        elif config['arch']['args']['prediction']['type'] == 'Attn':
            self.converter = AttnLabelConverter(config['dataset']['alphabet'])
        self.net = get_model(img_channel, len(self.converter.character),
                             config['arch']['args'])
        self.net.load_state_dict(checkpoint['state_dict'])
        # self.net = torch.jit.load('crnn_lite_gpu.pt')
        self.net.to(self.device)
        self.net.eval()
        sample_input = torch.zeros(
            (2, img_channel, img_h, img_w)).to(self.device)
        self.net.get_batch_max_length(sample_input)

    def predict(self, img_path, model_save_path=None):
        """
        Run prediction on the given image; supports an image path or a numpy array
        :param img_path: path to the image
        :return:
        """
        assert os.path.exists(img_path), 'file does not exist'
        img = self.pre_processing(img_path)
        tensor = self.transform(img)
        tensor = tensor.unsqueeze(dim=0)

        tensor = tensor.to(self.device)
        preds, tensor_img = self.net(tensor)

        preds = preds.softmax(dim=2).detach().cpu().numpy()
        # result = decode(preds, self.alphabet, raw=True)
        # print(result)
        result = self.converter.decode(preds)
        if model_save_path is not None:
            # export the model for deployment
            save(self.net, tensor, model_save_path)
        return result, tensor_img

    def pre_processing(self, img_path):
        """
        Preprocess the image: resize it to the target height first; if the resized width is
        smaller than the target width, pad with black pixels, otherwise force-resize to the target width
        :param img_path: path to the image
        :return:
        """
        img = cv2.imread(img_path, 1 if self.img_mode != 'GRAY' else 0)
        if self.img_mode == 'RGB':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        ratio_h = float(self.img_h) / h
        new_w = int(w * ratio_h)

        if new_w < self.img_w:
            img = cv2.resize(img, (new_w, self.img_h))
            step = np.zeros((self.img_h, self.img_w - new_w, img.shape[-1]),
                            dtype=img.dtype)
            img = np.column_stack((img, step))
        else:
            img = cv2.resize(img, (self.img_w, self.img_h))
        return img
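# Usage sketch for PytorchNet. The checkpoint and image paths below are hypothetical placeholders;
# the checkpoint is assumed to contain the 'epoch', 'config' and 'state_dict' keys expected by
# __init__ above.
if __name__ == '__main__':
    recognizer = PytorchNet(model_path='output/crnn_best.pth', gpu_id=0)
    result, _ = recognizer.predict('test_images/word_01.jpg')
    print(result)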
Example 3
def demo(opt):
    """ model configuration """
    lists = []  # holds the text recognized from the images assumed to be the destination

    converter = AttnLabelConverter(opt.character)  #ATTN

    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    model = Model(opt)  # Model class from model.py

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling,
          opt.Prediction)  # print the model input parameters

    model = torch.nn.DataParallel(model).to(device)  # run data-parallel on the GPU(s)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model,
                                     map_location=device))  # load the model parameters

    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data1 = RawDataset(root=opt.image_folder1,
                            opt=opt)  # use RawDataset (signboard detection results)
    demo_data2 = RawDataset(root=opt.image_folder2,
                            opt=opt)  # use RawDataset (Google Maps text detection results)

    demo_loader1 = torch.utils.data.DataLoader(demo_data1,
                                               batch_size=opt.batch_size,
                                               shuffle=False,
                                               num_workers=int(opt.workers),
                                               collate_fn=AlignCollate_demo,
                                               pin_memory=True)
    demo_loader2 = torch.utils.data.DataLoader(demo_data2,
                                               batch_size=opt.batch_size,
                                               shuffle=False,
                                               num_workers=int(opt.workers),
                                               collate_fn=AlignCollate_demo,
                                               pin_memory=True)

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader1:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)

            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                             1).fill_(0).to(device)

            #ATTn
            preds = model(image, text_for_pred, is_train=False)

            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)

            log = open(f'./log_demo_result.txt', 'a')  # open in append mode
            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

            print(f'{dashed_line}\n{head}\n{dashed_line}')  # print the table header
            log.write(
                f'{dashed_line}\n{head}\n{dashed_line}\n')  # write the table header to the txt log

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list,
                                                     preds_str,
                                                     preds_max_prob):
                pred_EOS = pred.find('[s]')
                pred = pred[:
                            pred_EOS]  # prune after "end of sentence" token ([s])
                pred_max_prob = pred_max_prob[:pred_EOS]

                # calculate confidence score (= product of pred_max_prob)
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]

                lists.append(pred)
                print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}'
                      )  # print the result
                log.write(
                    f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n'
                )  # write the result to the txt log

            log.close()  # close the file

    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader2:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                             1).fill_(0).to(device)

            #ATTn
            preds = model(image, text_for_pred, is_train=False)

            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)

            log = open(f'./log_demo_result.txt', 'a')  # open in append mode
            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

            print(f'{dashed_line}\n{head}\n{dashed_line}')  # print the table header
            log.write(
                f'{dashed_line}\n{head}\n{dashed_line}\n')  # write the table header to the txt log

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list,
                                                     preds_str,
                                                     preds_max_prob):
                pred_EOS = pred.find('[s]')
                pred = pred[:
                            pred_EOS]  # prune after "end of sentence" token ([s])
                pred_max_prob = pred_max_prob[:pred_EOS]

                # calculate the confidence score
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]

                print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}'
                      )  # print the result
                log.write(
                    f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n'
                )  # write the result to the txt log
                if pred in lists:
                    print(pred + " is a valid destination.")
                else:
                    print(pred + " is not a valid destination.")

            log.close()  # close the file
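# Sketch of how demo() above can be driven from the command line. The argument names mirror the
# clovaai deep-text-recognition-benchmark demo script; the folder paths, model path and character
# set below are hypothetical placeholders and should be replaced with the real ones.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_folder1', default='result/signboards/', help='signboard detection crops')
    parser.add_argument('--image_folder2', default='result/gmap_text/', help='Google Maps text crops')
    parser.add_argument('--saved_model', default='TPS-ResNet-BiLSTM-Attn.pth')
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=192)
    parser.add_argument('--batch_max_length', type=int, default=25)
    parser.add_argument('--imgH', type=int, default=32)
    parser.add_argument('--imgW', type=int, default=100)
    parser.add_argument('--rgb', action='store_true')
    parser.add_argument('--character', default='0123456789abcdefghijklmnopqrstuvwxyz')
    parser.add_argument('--PAD', action='store_true')
    parser.add_argument('--Transformation', default='TPS')
    parser.add_argument('--FeatureExtraction', default='ResNet')
    parser.add_argument('--SequenceModeling', default='BiLSTM')
    parser.add_argument('--Prediction', default='Attn')
    parser.add_argument('--num_fiducial', type=int, default=20)
    parser.add_argument('--input_channel', type=int, default=1)
    parser.add_argument('--output_channel', type=int, default=512)
    parser.add_argument('--hidden_size', type=int, default=256)
    opt = parser.parse_args()
    demo(opt)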
def train(opt):
    lib.print_model_settings(locals().copy())

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
        text_len = opt.batch_max_length+2
    else:
        converter = CTCLabelConverter(opt.character)
        text_len = opt.batch_max_length

    opt.classes = converter.character
    
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    log = open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)

    train_dataset = LmdbStyleDataset(root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size*2, # *2 to sample different images for training the encoder and for the discriminator's real images
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    
    print('-' * 80)
    
    valid_dataset = LmdbStyleDataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size*2, # *2 to sample different images for training the encoder and for the discriminator's real images
        shuffle=False,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    text_dataset = text_gen(opt)
    text_loader = torch.utils.data.DataLoader(
        text_dataset, batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
        pin_memory=True, drop_last=True)
    opt.num_class = len(converter.character)
    

    c_code_size = opt.latent
    cEncoder = GlobalContentEncoder(opt.num_class, text_len, opt.char_embed_size, c_code_size)
    ocrModel = ModelV1(opt)

    
    genModel = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier)
    g_ema = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier)
   
    disEncModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier, input_dim=opt.input_channel, code_s_dim=c_code_size)
    
    accumulate(g_ema, genModel, 0)
    
    # uCriterion = torch.nn.MSELoss()
    # sCriterion = torch.nn.MSELoss()
    # if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
    #     ocrCriterion = torch.nn.L1Loss()
    # else:
    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        print('Not implemented error')
        sys.exit()
        # ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    cEncoder= torch.nn.DataParallel(cEncoder).to(device)
    cEncoder.train()
    genModel = torch.nn.DataParallel(genModel).to(device)
    g_ema = torch.nn.DataParallel(g_ema).to(device)
    genModel.train()
    g_ema.eval()

    disEncModel = torch.nn.DataParallel(disEncModel).to(device)
    disEncModel.train()

    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    if opt.ocrFixed:
        if opt.Transformation == 'TPS':
            ocrModel.module.Transformation.eval()
        ocrModel.module.FeatureExtraction.eval()
        ocrModel.module.AdaptiveAvgPool.eval()
        # ocrModel.module.SequenceModeling.eval()
        ocrModel.module.Prediction.eval()
    else:
        ocrModel.train()

    g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1)
    d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1)

    
    optimizer = optim.Adam(
        list(genModel.parameters())+list(cEncoder.parameters()),
        lr=opt.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    dis_optimizer = optim.Adam(
        disEncModel.parameters(),
        lr=opt.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )
    
    ocr_optimizer = optim.Adam(
        ocrModel.parameters(),
        lr=opt.lr,
        betas=(0.9, 0.99),
    )


    ## Loading pre-trained files
    if opt.modelFolderFlag:
        if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth")))>0:
            opt.saved_synth_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth"))[-1]

    if opt.saved_ocr_model !='' and opt.saved_ocr_model !='None':
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        checkpoint = torch.load(opt.saved_ocr_model)
        ocrModel.load_state_dict(checkpoint)
    
    # if opt.saved_gen_model !='' and opt.saved_gen_model !='None':
    #     print(f'loading pretrained gen model from {opt.saved_gen_model}')
    #     checkpoint = torch.load(opt.saved_gen_model, map_location=lambda storage, loc: storage)
    #     genModel.module.load_state_dict(checkpoint['g'])
    #     g_ema.module.load_state_dict(checkpoint['g_ema'])

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)
        
        # styleModel.load_state_dict(checkpoint['styleModel'])
        # mixModel.load_state_dict(checkpoint['mixModel'])
        genModel.load_state_dict(checkpoint['genModel'])
        g_ema.load_state_dict(checkpoint['g_ema'])
        disEncModel.load_state_dict(checkpoint['disEncModel'])
        ocrModel.load_state_dict(checkpoint['ocrModel'])
        
        optimizer.load_state_dict(checkpoint["optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])
        ocr_optimizer.load_state_dict(checkpoint["ocr_optimizer"])

    # if opt.imgReconLoss == 'l1':
    #     recCriterion = torch.nn.L1Loss()
    # elif opt.imgReconLoss == 'ssim':
    #     recCriterion = ssim
    # elif opt.imgReconLoss == 'ms-ssim':
    #     recCriterion = msssim
    

    # loss averager
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    loss_avg_unsup = Averager()
    loss_avg_sup = Averager()
    log_r1_val = Averager()
    log_avg_path_loss_val = Averager()
    log_avg_mean_path_length_avg = Averager()
    log_ada_aug_p = Averager()
    loss_avg_ocr_sup = Averager()
    loss_avg_ocr_unsup = Averager()

    """ final options """
    with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    
    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    
    #get schedulers
    scheduler = get_scheduler(optimizer,opt)
    dis_scheduler = get_scheduler(dis_optimizer,opt)
    ocr_scheduler = get_scheduler(ocr_optimizer,opt)

    start_time = time.time()
    iteration = start_iter
    cntr=0
    
    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    # loss_dict = {}

    accum = 0.5 ** (32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0
    ada_aug_step = opt.ada_target / opt.ada_length
    r_t_stat = 0
    epsilon = 10e-50
    # sample_z = torch.randn(opt.n_sample, opt.latent, device=device)

    while(True):
        # print(cntr)
        # train part
        if opt.lr_policy !="None":
            scheduler.step()
            dis_scheduler.step()
            ocr_scheduler.step()
        
        image_input_tensors, _, labels, _ = iter(train_loader).next()
        labels_z_c = iter(text_loader).next()

        image_input_tensors = image_input_tensors.to(device)
        gt_image_tensors = image_input_tensors[:opt.batch_size].detach()
        real_image_tensors = image_input_tensors[opt.batch_size:].detach()
        
        labels_gt = labels[:opt.batch_size]
        
        requires_grad(cEncoder, False)
        requires_grad(genModel, False)
        requires_grad(disEncModel, True)
        requires_grad(ocrModel, False)

        text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length)
        text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length)

        z_c_code = cEncoder(text_z_c)
        noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
        style=[]
        style.append(noise_style[0]*z_c_code)
        if len(noise_style)>1:
            style.append(noise_style[1]*z_c_code)
        
        if opt.zAlone:
            #to validate orig style gan results
            newstyle = []
            newstyle.append(style[0][:,:opt.latent])
            if len(style)>1:
                newstyle.append(style[1][:,:opt.latent])
            style = newstyle
        
        fake_img,_ = genModel(style, input_is_latent=opt.input_latent)
        
        # #unsupervised code prediction on generated image
        # u_pred_code = disEncModel(fake_img, mode='enc')
        # uCost = uCriterion(u_pred_code, z_code)

        # #supervised code prediction on gt image
        # s_pred_code = disEncModel(gt_image_tensors, mode='enc')
        # sCost = uCriterion(s_pred_code, gt_phoc_tensors)

        #Domain discriminator
        fake_pred = disEncModel(fake_img)
        real_pred = disEncModel(real_image_tensors)
        disCost = d_logistic_loss(real_pred, fake_pred)

        # dis_cost = disCost + opt.gamma_e*uCost + opt.beta*sCost
        loss_avg_dis.add(disCost)
        # loss_avg_sup.add(opt.beta*sCost)
        # loss_avg_unsup.add(opt.gamma_e * uCost)

        disEncModel.zero_grad()
        disCost.backward()
        dis_optimizer.step()

        d_regularize = cntr % opt.d_reg_every == 0

        if d_regularize:
            real_image_tensors.requires_grad = True
            real_pred = disEncModel(real_image_tensors)
            
            r1_loss = d_r1_loss(real_pred, real_image_tensors)

            disEncModel.zero_grad()
            (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward()

            dis_optimizer.step()
        log_r1_val.add(r1_loss)
        
        # Recognizer update
        if not opt.ocrFixed and not opt.zAlone:
            requires_grad(disEncModel, False)
            requires_grad(ocrModel, True)

            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(gt_image_tensors, text_gt, is_train=True)
                preds_size = torch.IntTensor([preds_recon.size(1)] * opt.batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_gt, preds_size, length_gt)
            else:
                print("Not implemented error")
                sys.exit()
            
            ocrModel.zero_grad()
            ocrCost.backward()
            # torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
            ocr_optimizer.step()
            loss_avg_ocr_sup.add(ocrCost)
        else:
            loss_avg_ocr_sup.add(torch.tensor(0.0))


        # [Word Generator] update
        # image_input_tensors, _, labels, _ = iter(train_loader).next()
        labels_z_c = iter(text_loader).next()

        # image_input_tensors = image_input_tensors.to(device)
        # gt_image_tensors = image_input_tensors[:opt.batch_size]
        # real_image_tensors = image_input_tensors[opt.batch_size:]
        
        # labels_gt = labels[:opt.batch_size]

        requires_grad(cEncoder, True)
        requires_grad(genModel, True)
        requires_grad(disEncModel, False)
        requires_grad(ocrModel, False)

        text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length)
        
        z_c_code = cEncoder(text_z_c)
        noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
        style=[]
        style.append(noise_style[0]*z_c_code)
        if len(noise_style)>1:
            style.append(noise_style[1]*z_c_code)

        if opt.zAlone:
            #to validate orig style gan results
            newstyle = []
            newstyle.append(style[0][:,:opt.latent])
            if len(style)>1:
                newstyle.append(style[1][:,:opt.latent])
            style = newstyle
        
        fake_img,_ = genModel(style, input_is_latent=opt.input_latent)

        fake_pred = disEncModel(fake_img)
        disGenCost = g_nonsaturating_loss(fake_pred)

        if opt.zAlone:
            ocrCost = torch.tensor(0.0)
        else:
            #Compute OCR prediction (Reconstruction of content)
            # text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device)
            # length_for_pred = torch.IntTensor([opt.batch_max_length] * opt.batch_size).to(device)
            
            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(fake_img, text_z_c, is_train=False)
                preds_size = torch.IntTensor([preds_recon.size(1)] * opt.batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_z_c, preds_size, length_z_c)
            else:
                print("Not implemented error")
                sys.exit()
        
        genModel.zero_grad()
        cEncoder.zero_grad()

        gen_enc_cost = disGenCost + opt.ocrWeight * ocrCost
        grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, retain_graph=True)[0]
        loss_grad_fake_OCR = 10**6*torch.mean(grad_fake_OCR**2)
        grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, retain_graph=True)[0]
        loss_grad_fake_adv = 10**6*torch.mean(grad_fake_adv**2)
        
        if opt.grad_balance:
            gen_enc_cost.backward(retain_graph=True)
            grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, create_graph=True, retain_graph=True)[0]
            grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, create_graph=True, retain_graph=True)[0]
            a = opt.ocrWeight * torch.div(torch.std(grad_fake_adv), epsilon+torch.std(grad_fake_OCR))
            if a is None:
                print(ocrCost, disGenCost, torch.std(grad_fake_adv), torch.std(grad_fake_OCR))
            if a>1000 or a<0.0001:
                print(a)
            
            ocrCost = a.detach() * ocrCost
            gen_enc_cost = disGenCost + ocrCost
            gen_enc_cost.backward(retain_graph=True)
            grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, create_graph=False, retain_graph=True)[0]
            grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, create_graph=False, retain_graph=True)[0]
            loss_grad_fake_OCR = 10 ** 6 * torch.mean(grad_fake_OCR ** 2)
            loss_grad_fake_adv = 10 ** 6 * torch.mean(grad_fake_adv ** 2)
            with torch.no_grad():
                gen_enc_cost.backward()
        else:
            gen_enc_cost.backward()

        loss_avg_gen.add(disGenCost)
        loss_avg_ocr_unsup.add(opt.ocrWeight * ocrCost)

        optimizer.step()
        
        g_regularize = cntr % opt.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, opt.batch_size // opt.path_batch_shrink)
            # image_input_tensors, _, labels, _ = iter(train_loader).next()
            labels_z_c = iter(text_loader).next()

            # image_input_tensors = image_input_tensors.to(device)
            # gt_image_tensors = image_input_tensors[:path_batch_size]

            # labels_gt = labels[:path_batch_size]

            text_z_c, length_z_c = converter.encode(labels_z_c[:path_batch_size], batch_max_length=opt.batch_max_length)
            # text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length)
        
            z_c_code = cEncoder(text_z_c)
            noise_style = mixing_noise_style(path_batch_size, opt.latent, opt.mixing, device)
            style=[]
            style.append(noise_style[0]*z_c_code)
            if len(noise_style)>1:
                style.append(noise_style[1]*z_c_code)

            if opt.zAlone:
                #to validate orig style gan results
                newstyle = []
                newstyle.append(style[0][:,:opt.latent])
                if len(style)>1:
                    newstyle.append(style[1][:,:opt.latent])
                style = newstyle

            fake_img, grad = genModel(style, return_latents=True, g_path_regularize=True, mean_path_length=mean_path_length)
            
            decay = 0.01
            path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

            mean_path_length_orig = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
            path_loss = (path_lengths - mean_path_length_orig).pow(2).mean()
            mean_path_length = mean_path_length_orig.detach().item()

            genModel.zero_grad()
            cEncoder.zero_grad()
            weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss

            if opt.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            optimizer.step()

            # mean_path_length_avg = (
            #     reduce_sum(mean_path_length).item() / get_world_size()
            # )
            # commented out above for the multi-GPU, non-distributed setting
            mean_path_length_avg = mean_path_length

        accumulate(g_ema, genModel, accum)

        log_avg_path_loss_val.add(path_loss)
        log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg))
        log_ada_aug_p.add(torch.tensor(ada_aug_p))
        

        if get_rank() == 0:
            if wandb and opt.wandb:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        "Mean Path Length": mean_path_length,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,   
                        "Path Length": path_length_val,
                    }
                )
        
        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' 
            
            #generate paired content with similar style
            labels_z_c_1 = iter(text_loader).next()
            labels_z_c_2 = iter(text_loader).next()
            
            text_z_c_1, length_z_c_1 = converter.encode(labels_z_c_1, batch_max_length=opt.batch_max_length)
            text_z_c_2, length_z_c_2 = converter.encode(labels_z_c_2, batch_max_length=opt.batch_max_length)

            z_c_code_1 = cEncoder(text_z_c_1)
            z_c_code_2 = cEncoder(text_z_c_2)

            
            style_c1_s1 = []
            style_c2_s1 = []
            style_s1 = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
            style_c1_s1.append(style_s1[0]*z_c_code_1)
            style_c2_s1.append(style_s1[0]*z_c_code_2)
            if len(style_s1)>1:
                style_c1_s1.append(style_s1[1]*z_c_code_1)
                style_c2_s1.append(style_s1[1]*z_c_code_2)
            
            noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
            style_c1_s2 = []
            style_c1_s2.append(noise_style[0]*z_c_code_1)
            if len(noise_style)>1:
                style_c1_s2.append(noise_style[1]*z_c_code_1)
            
            if opt.zAlone:
                #to validate orig style gan results
                newstyle = []
                newstyle.append(style_c1_s1[0][:,:opt.latent])
                if len(style_c1_s1)>1:
                    newstyle.append(style_c1_s1[1][:,:opt.latent])
                style_c1_s1 = newstyle
                style_c2_s1 = newstyle
                style_c1_s2 = newstyle
            
            fake_img_c1_s1, _ = g_ema(style_c1_s1, input_is_latent=opt.input_latent)
            fake_img_c2_s1, _ = g_ema(style_c2_s1, input_is_latent=opt.input_latent)
            fake_img_c1_s2, _ = g_ema(style_c1_s2, input_is_latent=opt.input_latent)

            if not opt.zAlone:
                #Run OCR prediction
                if 'CTC' in opt.Prediction:
                    preds = ocrModel(fake_img_c1_s1, text_z_c_1, is_train=False)
                    preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
                    _, preds_index = preds.max(2)
                    preds_str_fake_img_c1_s1 = converter.decode(preds_index.data, preds_size.data)

                    preds = ocrModel(fake_img_c2_s1, text_z_c_2, is_train=False)
                    preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
                    _, preds_index = preds.max(2)
                    preds_str_fake_img_c2_s1 = converter.decode(preds_index.data, preds_size.data)

                    preds = ocrModel(fake_img_c1_s2, text_z_c_1, is_train=False)
                    preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
                    _, preds_index = preds.max(2)
                    preds_str_fake_img_c1_s2 = converter.decode(preds_index.data, preds_size.data)

                    preds = ocrModel(gt_image_tensors, text_gt, is_train=False)
                    preds_size = torch.IntTensor([preds.size(1)] * gt_image_tensors.shape[0])
                    _, preds_index = preds.max(2)
                    preds_str_gt = converter.decode(preds_index.data, preds_size.data)

                else:
                    print("Not implemented error")
                    sys.exit()
            else:
                preds_str_fake_img_c1_s1 = [':None:'] * fake_img_c1_s1.shape[0]
                preds_str_gt = [':None:'] * fake_img_c1_s1.shape[0] 

            os.makedirs(os.path.join(opt.trainDir,str(iteration)), exist_ok=True)
            for trImgCntr in range(opt.batch_size):
                try:
                    save_image(tensor2im(fake_img_c1_s1[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c1_s1_'+labels_z_c_1[trImgCntr]+'_ocr:'+preds_str_fake_img_c1_s1[trImgCntr]+'.png'))
                    if not opt.zAlone:
                        save_image(tensor2im(fake_img_c2_s1[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c2_s1_'+labels_z_c_2[trImgCntr]+'_ocr:'+preds_str_fake_img_c2_s1[trImgCntr]+'.png'))
                        save_image(tensor2im(fake_img_c1_s2[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c1_s2_'+labels_z_c_1[trImgCntr]+'_ocr:'+preds_str_fake_img_c1_s2[trImgCntr]+'.png'))
                        if trImgCntr<gt_image_tensors.shape[0]:
                            save_image(tensor2im(gt_image_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_gt_act:'+labels_gt[trImgCntr]+'_ocr:'+preds_str_gt[trImgCntr]+'.png'))
                except:
                    print('Warning while saving training image')
            
            elapsed_time = time.time() - start_time
            # for log
            
            with open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log:

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}]  \
                    Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\
                    Train UnSup OCR loss: {loss_avg_ocr_unsup.val():0.5f}, Train Sup OCR loss: {loss_avg_ocr_sup.val():0.5f}, \
                    Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \
                    Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \
                    Elapsed_time: {elapsed_time:0.5f}'
                
                
                #plotting
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-UnSup-OCR-Loss'), loss_avg_ocr_unsup.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Sup-OCR-Loss'), loss_avg_ocr_sup.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-r1_val'), log_r1_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-path_loss_val'), log_avg_path_loss_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-ada_aug_p'), log_ada_aug_p.val().item())

                
                print(loss_log)

                loss_avg_dis.reset()
                loss_avg_gen.reset()
                loss_avg_ocr_unsup.reset()
                loss_avg_ocr_sup.reset()
                log_r1_val.reset()
                log_avg_path_loss_val.reset()
                log_avg_mean_path_length_avg.reset()
                log_ada_aug_p.reset()
                

            lib.plot.flush()

        lib.plot.tick()

        # save model every 1e+4 iterations
        if (iteration) % 1e+4 == 0:
            torch.save({
                'cEncoder':cEncoder.state_dict(),
                'genModel':genModel.state_dict(),
                'g_ema':g_ema.state_dict(),
                'ocrModel':ocrModel.state_dict(),
                'disEncModel':disEncModel.state_dict(),
                'optimizer':optimizer.state_dict(),
                'ocr_optimizer':ocr_optimizer.state_dict(),
                'dis_optimizer':dis_optimizer.state_dict()}, 
                os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth'))
            

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
        cntr+=1
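# Toy illustration of the opt.grad_balance step inside train() above: the OCR loss is rescaled so
# that its gradient magnitude w.r.t. the generated image is comparable to that of the adversarial
# loss. The tensors and stand-in losses below are illustrative only, not the real model outputs.
if __name__ == '__main__':
    import torch

    fake_img = torch.randn(4, 1, 32, 100, requires_grad=True)
    ocrCost = (fake_img ** 2).mean()        # stand-in for the CTC recognition loss
    disGenCost = fake_img.abs().mean()      # stand-in for the non-saturating GAN loss
    ocrWeight, epsilon = 1.0, 1e-50

    grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, retain_graph=True)[0]
    grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, retain_graph=True)[0]
    a = ocrWeight * torch.std(grad_fake_adv) / (epsilon + torch.std(grad_fake_OCR))
    gen_enc_cost = disGenCost + a.detach() * ocrCost   # balanced generator objective
    print(f'balance factor a = {a.item():.4f}, total cost = {gen_enc_cost.item():.4f}')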
Example 5
def train(opt):
    lib.print_model_settings(locals().copy())

    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    log = open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)

    train_dataset, train_dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size,
        sampler=data_sampler(train_dataset, shuffle=True, distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    log.write(train_dataset_log)
    print('-' * 80)

    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        sampler=data_sampler(valid_dataset, shuffle=False, distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)
    
    opt.num_class = len(converter.character)

    
    # styleModel = StyleTensorEncoder(input_dim=opt.input_channel)
    # genModel = AdaIN_Tensor_WordGenerator(opt)
    # disModel = MsImageDisV2(opt)

    # styleModel = StyleLatentEncoder(input_dim=opt.input_channel, norm='none')
    # mixModel = Mixer(opt,nblk=3, dim=opt.latent)
    genModel = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class, channel_multiplier=opt.channel_multiplier).to(device)
    disModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier, input_dim=opt.input_channel).to(device)
    g_ema = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class, channel_multiplier=opt.channel_multiplier).to(device)
    ocrModel = ModelV1(opt).to(device)
    accumulate(g_ema, genModel, 0)

    # #  weight initialization
    # for currModel in [styleModel, mixModel]:
    #     for name, param in currModel.named_parameters():
    #         if 'localization_fc2' in name:
    #             print(f'Skip {name} as it is already initialized')
    #             continue
    #         try:
    #             if 'bias' in name:
    #                 init.constant_(param, 0.0)
    #             elif 'weight' in name:
    #                 init.kaiming_normal_(param)
    #         except Exception as e:  # for batchnorm.
    #             if 'weight' in name:
    #                 param.data.fill_(1)
    #             continue

    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
        ocrCriterion = torch.nn.L1Loss()
    else:
        if 'CTC' in opt.Prediction:
            ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
        else:
            ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    # vggRecCriterion = torch.nn.L1Loss()
    # vggModel = VGGPerceptualLossModel(models.vgg19(pretrained=True), vggRecCriterion)
    
    print('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length)

    if opt.distributed:
        genModel = torch.nn.parallel.DistributedDataParallel(
            genModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False,
        )
        
        disModel = torch.nn.parallel.DistributedDataParallel(
            disModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False,
        )
        ocrModel = torch.nn.parallel.DistributedDataParallel(
            ocrModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False
        )
    
    # styleModel = torch.nn.DataParallel(styleModel).to(device)
    # styleModel.train()
    
    # mixModel = torch.nn.DataParallel(mixModel).to(device)
    # mixModel.train()
    
    # genModel = torch.nn.DataParallel(genModel).to(device)
    # g_ema = torch.nn.DataParallel(g_ema).to(device)
    genModel.train()
    g_ema.eval()

    # disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    # vggModel = torch.nn.DataParallel(vggModel).to(device)
    # vggModel.eval()

    # ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    # if opt.distributed:
    #     ocrModel.module.Transformation.eval()
    #     ocrModel.module.FeatureExtraction.eval()
    #     ocrModel.module.AdaptiveAvgPool.eval()
    #     # ocrModel.module.SequenceModeling.eval()
    #     ocrModel.module.Prediction.eval()
    # else:
    #     ocrModel.Transformation.eval()
    #     ocrModel.FeatureExtraction.eval()
    #     ocrModel.AdaptiveAvgPool.eval()
    #     # ocrModel.SequenceModeling.eval()
    #     ocrModel.Prediction.eval()
    ocrModel.eval()

    if opt.distributed:
        g_module = genModel.module
        d_module = disModel.module
    else:
        g_module = genModel
        d_module = disModel

    g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1)
    d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1)

    optimizer = optim.Adam(
        genModel.parameters(),
        lr=opt.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    dis_optimizer = optim.Adam(
        disModel.parameters(),
        lr=opt.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )

    ## Loading pre-trained files
    if opt.modelFolderFlag:
        if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth")))>0:
            opt.saved_synth_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth"))[-1]

    if opt.saved_ocr_model !='' and opt.saved_ocr_model !='None':
        if not opt.distributed:
            ocrModel = torch.nn.DataParallel(ocrModel)
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        checkpoint = torch.load(opt.saved_ocr_model)
        ocrModel.load_state_dict(checkpoint)
        #temporary fix
        if not opt.distributed:
            ocrModel = ocrModel.module
    
    if opt.saved_gen_model !='' and opt.saved_gen_model !='None':
        print(f'loading pretrained gen model from {opt.saved_gen_model}')
        checkpoint = torch.load(opt.saved_gen_model, map_location=lambda storage, loc: storage)
        genModel.module.load_state_dict(checkpoint['g'])
        g_ema.module.load_state_dict(checkpoint['g_ema'])

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)
        
        # styleModel.load_state_dict(checkpoint['styleModel'])
        # mixModel.load_state_dict(checkpoint['mixModel'])
        genModel.load_state_dict(checkpoint['genModel'])
        g_ema.load_state_dict(checkpoint['g_ema'])
        disModel.load_state_dict(checkpoint['disModel'])
        
        optimizer.load_state_dict(checkpoint["optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])

    # if opt.imgReconLoss == 'l1':
    #     recCriterion = torch.nn.L1Loss()
    # elif opt.imgReconLoss == 'ssim':
    #     recCriterion = ssim
    # elif opt.imgReconLoss == 'ms-ssim':
    #     recCriterion = msssim
    

    # loss averager
    loss_avg = Averager()
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    loss_avg_imgRecon = Averager()
    loss_avg_vgg_per = Averager()
    loss_avg_vgg_sty = Averager()
    loss_avg_ocr = Averager()

    log_r1_val = Averager()
    log_avg_path_loss_val = Averager()
    log_avg_mean_path_length_avg = Averager()
    log_ada_aug_p = Averager()

    """ final options """
    with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    
    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    
    #get schedulers
    scheduler = get_scheduler(optimizer,opt)
    dis_scheduler = get_scheduler(dis_optimizer,opt)

    start_time = time.time()
    iteration = start_iter
    cntr=0
    
    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    accum = 0.5 ** (32 / (10 * 1000))
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0
    ada_aug_step = opt.ada_target / opt.ada_length
    r_t_stat = 0

    sample_z = torch.randn(opt.n_sample, opt.latent, device=device)

    while(True):
        # print(cntr)
        # train part
       
        if opt.lr_policy !="None":
            scheduler.step()
            dis_scheduler.step()
        
        image_input_tensors, image_gt_tensors, labels_1, labels_2 = iter(train_loader).next()

        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        requires_grad(genModel, False)
        # requires_grad(styleModel, False)
        # requires_grad(mixModel, False)
        requires_grad(disModel, True)

        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)
        
        
        #forward pass from style and word generator
        # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
        style = mixing_noise(opt.batch_size, opt.latent, opt.mixing, device)
        # scInput = mixModel(style,text_2)
        if 'CTC' in opt.Prediction:
            images_recon_2,_ = genModel(style, text_2, input_is_latent=opt.input_latent)
        else:
            images_recon_2,_ = genModel(style, text_2[:,1:-1], input_is_latent=opt.input_latent)
        
        #Domain discriminator: Dis update
        if opt.augment:
            image_gt_tensors_aug, _ = augment(image_gt_tensors, ada_aug_p)
            images_recon_2, _ = augment(images_recon_2, ada_aug_p)

        else:
            image_gt_tensors_aug = image_gt_tensors

        fake_pred = disModel(images_recon_2)
        real_pred = disModel(image_gt_tensors_aug)
        disCost = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = disCost*opt.disWeight
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        loss_avg_dis.add(disCost)

        disModel.zero_grad()
        disCost.backward()
        dis_optimizer.step()

        if opt.augment and opt.augment_p == 0:
            ada_augment += torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]), device=device
            )
            ada_augment = reduce_sum(ada_augment)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()

                r_t_stat = pred_signs / n_pred

                if r_t_stat > opt.ada_target:
                    sign = 1

                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        d_regularize = cntr % opt.d_reg_every == 0

        if d_regularize:
            image_gt_tensors.requires_grad = True
            image_input_tensors.requires_grad = True
            cat_tensor = image_gt_tensors
            real_pred = disModel(cat_tensor)
            
            r1_loss = d_r1_loss(real_pred, cat_tensor)

            disModel.zero_grad()
            (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward()

            dis_optimizer.step()

        loss_dict["r1"] = r1_loss

        
        # #[Style Encoder] + [Word Generator] update
        image_input_tensors, image_gt_tensors, labels_1, labels_2 = iter(train_loader).next()
        
        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        requires_grad(genModel, True)
        # requires_grad(styleModel, True)
        # requires_grad(mixModel, True)
        requires_grad(disModel, False)

        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

        # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
        # scInput = mixModel(style,text_2)

        # images_recon_2,_ = genModel([scInput], input_is_latent=opt.input_latent)
        style = mixing_noise(batch_size, opt.latent, opt.mixing, device)
        
        if 'CTC' in opt.Prediction:
            images_recon_2, _ = genModel(style, text_2)
        else:
            images_recon_2, _ = genModel(style, text_2[:,1:-1])

        if opt.augment:
            images_recon_2, _ = augment(images_recon_2, ada_aug_p)

        fake_pred = disModel(images_recon_2)
        disGenCost = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = disGenCost

        # # #Adversarial loss
        # # disGenCost = disModel.module.calc_gen_loss(torch.cat((images_recon_2,image_input_tensors),dim=1))

        # #Input reconstruction loss
        # recCost = recCriterion(images_recon_2,image_gt_tensors)

        # #vgg loss
        # vggPerCost, vggStyleCost = vggModel(image_gt_tensors, images_recon_2)
        #ocr loss
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
        if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
            preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False, returnFeat=opt.contentLoss)
            preds_gt = ocrModel(image_gt_tensors, text_for_pred, is_train=False, returnFeat=opt.contentLoss)
            ocrCost = ocrCriterion(preds_recon, preds_gt)
        else:
            if 'CTC' in opt.Prediction:
                
                preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False)
                # preds_o = preds_recon[:, :text_1.shape[1], :]
                preds_size = torch.IntTensor([preds_recon.size(1)] * batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_2, preds_size, length_2)
                
                #predict ocr recognition on generated images
                # preds_recon_size = torch.IntTensor([preds_recon.size(1)] * batch_size)
                _, preds_recon_index = preds_recon.max(2)
                labels_o_ocr = converter.decode(preds_recon_index.data, preds_size.data)

                #predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False)
                # preds_s = preds_s[:, :text_1.shape[1] - 1, :]
                preds_s_size = torch.IntTensor([preds_s.size(1)] * batch_size)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index.data, preds_s_size.data)

                #predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False)
                # preds_sc = preds_sc[:, :text_2.shape[1] - 1, :]
                preds_sc_size = torch.IntTensor([preds_sc.size(1)] * batch_size)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index.data, preds_sc_size.data)

            else:
                preds_recon = ocrModel(images_recon_2, text_for_pred[:, :-1], is_train=False)  # align with Attention.forward
                target_2 = text_2[:, 1:]  # without [GO] Symbol
                ocrCost = ocrCriterion(preds_recon.view(-1, preds_recon.shape[-1]), target_2.contiguous().view(-1))

                #predict ocr recognition on generated images
                _, preds_o_index = preds_recon.max(2)
                labels_o_ocr = converter.decode(preds_o_index, length_for_pred)
                for idx, pred in enumerate(labels_o_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_o_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

                #predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors, text_for_pred, is_train=False)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index, length_for_pred)
                for idx, pred in enumerate(labels_s_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_s_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                
                #predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_gt_tensors, text_for_pred, is_train=False)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index, length_for_pred)
                for idx, pred in enumerate(labels_sc_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_sc_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

        # cost =  opt.reconWeight*recCost + opt.disWeight*disGenCost + opt.vggPerWeight*vggPerCost + opt.vggStyWeight*vggStyleCost + opt.ocrWeight*ocrCost
        cost =  opt.disWeight*disGenCost + opt.ocrWeight*ocrCost

        # styleModel.zero_grad()
        genModel.zero_grad()
        # mixModel.zero_grad()
        disModel.zero_grad()
        # vggModel.zero_grad()
        ocrModel.zero_grad()
        
        cost.backward()
        optimizer.step()
        loss_avg.add(cost)

        g_regularize = cntr % opt.g_reg_every == 0

        if g_regularize:
            image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(iter(train_loader))
        
            image_input_tensors = image_input_tensors.to(device)
            image_gt_tensors = image_gt_tensors.to(device)
            batch_size = image_input_tensors.size(0)

            text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
            text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)

            path_batch_size = max(1, batch_size // opt.path_batch_shrink)

            # style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
            # scInput = mixModel(style,text_2)

            # images_recon_2, latents = genModel([scInput],input_is_latent=opt.input_latent, return_latents=True)

            style = mixing_noise(path_batch_size, opt.latent, opt.mixing, device)
            
            
            if 'CTC' in opt.Prediction:
                images_recon_2, latents = genModel(style, text_2[:path_batch_size], return_latents=True)
            else:
                images_recon_2, latents = genModel(style, text_2[:path_batch_size,1:-1], return_latents=True)
            
            
            path_loss, mean_path_length, path_lengths = g_path_regularize(
                images_recon_2, latents, mean_path_length
            )

            genModel.zero_grad()
            weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss

            if opt.path_batch_shrink:
                weighted_path_loss += 0 * images_recon_2[0, 0, 0, 0]

            weighted_path_loss.backward()

            optimizer.step()

            mean_path_length_avg = (
                reduce_sum(mean_path_length).item() / get_world_size()
            )

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()


        #Individual losses
        loss_avg_gen.add(opt.disWeight*disGenCost)
        loss_avg_imgRecon.add(torch.tensor(0.0))
        loss_avg_vgg_per.add(torch.tensor(0.0))
        loss_avg_vgg_sty.add(torch.tensor(0.0))
        loss_avg_ocr.add(opt.ocrWeight*ocrCost)

        log_r1_val.add(loss_reduced["r1"])
        log_avg_path_loss_val.add(loss_reduced["path"])
        log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg))
        log_ada_aug_p.add(torch.tensor(ada_aug_p))
        
        if get_rank() == 0:
            # pbar.set_description(
            #     (
            #         f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
            #         f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
            #         f"augment: {ada_aug_p:.4f}"
            #     )
            # )

            if wandb and opt.wandb:
                wandb.log(
                    {
                        "Generator": g_loss_val,
                        "Discriminator": d_loss_val,
                        "Augment": ada_aug_p,
                        "Rt": r_t_stat,
                        "R1": r1_val,
                        "Path Length Regularization": path_loss_val,
                        "Mean Path Length": mean_path_length,
                        "Real Score": real_score_val,
                        "Fake Score": fake_score_val,   
                        "Path Length": path_length_val,
                    }
                )
            # if cntr % 100 == 0:
            #     with torch.no_grad():
            #         g_ema.eval()
            #         sample, _ = g_ema([scInput[:,:opt.latent],scInput[:,opt.latent:]])
            #         utils.save_image(
            #             sample,
            #             os.path.join(opt.trainDir, f"sample_{str(cntr).zfill(6)}.png"),
            #             nrow=int(opt.n_sample ** 0.5),
            #             normalize=True,
            #             range=(-1, 1),
            #         )


        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' 
            
            #Save training images
            curr_batch_size = style[0].shape[0]
            images_recon_2, _ = g_ema(style, text_2[:curr_batch_size], input_is_latent=opt.input_latent)
            
            os.makedirs(os.path.join(opt.trainDir,str(iteration)), exist_ok=True)
            for trImgCntr in range(batch_size):
                try:
                    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
                        save_image(tensor2im(image_input_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_sInput_'+labels_1[trImgCntr]+'.png'))
                        save_image(tensor2im(image_gt_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_csGT_'+labels_2[trImgCntr]+'.png'))
                        save_image(tensor2im(images_recon_2[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_csRecon_'+labels_2[trImgCntr]+'.png'))
                    else:
                        save_image(tensor2im(image_input_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_sInput_'+labels_1[trImgCntr]+'_'+labels_s_ocr[trImgCntr]+'.png'))
                        save_image(tensor2im(image_gt_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_csGT_'+labels_2[trImgCntr]+'_'+labels_sc_ocr[trImgCntr]+'.png'))
                        save_image(tensor2im(images_recon_2[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_csRecon_'+labels_2[trImgCntr]+'_'+labels_o_ocr[trImgCntr]+'.png'))
                except Exception as e:
                    print('Warning while saving training image:', e)
            
            elapsed_time = time.time() - start_time
            # for log
            
            with open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log:
                # styleModel.eval()
                genModel.eval()
                g_ema.eval()
                # mixModel.eval()
                disModel.eval()
                
                with torch.no_grad():                    
                    valid_loss, infer_time, length_of_data = validation_synth_v6(
                        iteration, g_ema, ocrModel, disModel, ocrCriterion, valid_loader, converter, opt)
                
                # styleModel.train()
                genModel.train()
                # mixModel.train()
                disModel.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train Synth loss: {loss_avg.val():0.5f}, \
                    Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\
                    Train OCR loss: {loss_avg_ocr.val():0.5f}, \
                    Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \
                    Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \
                    Valid Synth loss: {valid_loss[0]:0.5f}, \
                    Valid Dis loss: {valid_loss[1]:0.5f}, Valid Gen loss: {valid_loss[2]:0.5f}, \
                    Valid OCR loss: {valid_loss[6]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                
                
                #plotting
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Synth-Loss'), loss_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item())
                
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Train-ImgRecon1-Loss'), loss_avg_imgRecon.val().item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Train-VGG-Per-Loss'), loss_avg_vgg_per.val().item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Train-VGG-Sty-Loss'), loss_avg_vgg_sty.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-OCR-Loss'), loss_avg_ocr.val().item())

                lib.plot.plot(os.path.join(opt.plotDir,'Train-r1_val'), log_r1_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-path_loss_val'), log_avg_path_loss_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-ada_aug_p'), log_ada_aug_p.val().item())

                lib.plot.plot(os.path.join(opt.plotDir,'Valid-Synth-Loss'), valid_loss[0].item())
                lib.plot.plot(os.path.join(opt.plotDir,'Valid-Dis-Loss'), valid_loss[1].item())

                lib.plot.plot(os.path.join(opt.plotDir,'Valid-Gen-Loss'), valid_loss[2].item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Valid-ImgRecon1-Loss'), valid_loss[3].item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Valid-VGG-Per-Loss'), valid_loss[4].item())
                # lib.plot.plot(os.path.join(opt.plotDir,'Valid-VGG-Sty-Loss'), valid_loss[5].item())
                lib.plot.plot(os.path.join(opt.plotDir,'Valid-OCR-Loss'), valid_loss[6].item())
                
                print(loss_log)

                loss_avg.reset()
                loss_avg_dis.reset()

                loss_avg_gen.reset()
                loss_avg_imgRecon.reset()
                loss_avg_vgg_per.reset()
                loss_avg_vgg_sty.reset()
                loss_avg_ocr.reset()

                log_r1_val.reset()
                log_avg_path_loss_val.reset()
                log_avg_mean_path_length_avg.reset()
                log_ada_aug_p.reset()
                

            lib.plot.flush()

        lib.plot.tick()

        # save a checkpoint every 10000 iterations
        if iteration % 10000 == 0:
            torch.save({
                # 'styleModel':styleModel.state_dict(),
                # 'mixModel':mixModel.state_dict(),
                'genModel':g_module.state_dict(),
                'g_ema':g_ema.state_dict(),
                'disModel':d_module.state_dict(),
                'optimizer':optimizer.state_dict(),
                'dis_optimizer':dis_optimizer.state_dict()}, 
                os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth'))
            

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
        cntr+=1
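The running-loss trackers used above (loss_avg, loss_avg_gen, loss_avg_ocr, ...) are only exercised through add/val/reset. A minimal sketch of such a helper, written purely for illustration and not taken from the original project, could look like this:

import torch

class Averager:
    """Running average of scalar loss tensors, matching the add/val/reset usage above."""

    def __init__(self):
        self.reset()

    def add(self, v: torch.Tensor):
        self.sum += v.detach().sum()
        self.count += v.numel()

    def val(self) -> torch.Tensor:
        return self.sum / max(self.count, 1)

    def reset(self):
        self.sum = torch.tensor(0.0)
        self.count = 0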
Ejemplo n.º 6
0
class OCR:
    def __init__(self):

        # model settings #
        self.path_model = 'model/TPS-ResNet-BiLSTM-Attn.pth'
        self.batch_size = 1
        self.batch_max_length = 25
        self.imgH = 32
        self.imgW = 100
        self.character = '0123456789abcdefghijklmnopqrstuvwxyz'
        self.Transformation = 'TPS'
        self.FeatureExtraction = 'ResNet'
        self.SequenceModeling = 'BiLSTM'
        self.Prediction = 'Attn'
        self.num_fiducial = 20
        self.input_channel = 1
        self.output_channel = 512
        self.hidden_size = 256
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        parser = argparse.ArgumentParser()
        parser.add_argument('--rgb', action='store_true', help='use rgb input')
        self.opt = parser.parse_args()

        self.opt.num_gpu = torch.cuda.device_count()

        # load model
        self.converter = AttnLabelConverter(self.character)
        self.opt.num_class = len(self.converter.character)

        if self.opt.rgb:
            self.opt.input_channel = 3
        self.model = Model(self.opt)
        self.model = torch.nn.DataParallel(self.model).to(self.device)

        # load model weights (map to the selected device so CPU-only machines also work)
        self.model.load_state_dict(
            torch.load(self.path_model, map_location=self.device))

    def run(self, img):
        with torch.no_grad():
            img = cv2.normalize(img,
                                None,
                                alpha=0,
                                beta=1,
                                norm_type=cv2.NORM_MINMAX,
                                dtype=cv2.CV_32F)
            img = cv2.resize(img,
                             dsize=(100, 32),
                             interpolation=cv2.INTER_CUBIC)
            image_tensor = img[np.newaxis, np.newaxis, ...]

            image = torch.from_numpy(image_tensor).float().to(self.device)
            # For max length prediction
            length_for_pred = torch.IntTensor([self.batch_max_length] *
                                              self.batch_size).to(self.device)
            text_for_pred = torch.LongTensor(self.batch_size,
                                             self.batch_max_length +
                                             1).fill_(0).to(self.device)

            preds = self.model(image, text_for_pred, is_train=False)

            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = self.converter.decode(preds_index, length_for_pred)

            pred = ''
            for i in range(len(preds_str)):
                pred += preds_str[i][:preds_str[i].find(
                    '[s]')]  # prune after "end of sentence" token ([s])
            print(pred)

            return pred
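A hypothetical usage of the OCR wrapper above; the file name is made up, and the crop is assumed to be a single grayscale word image:

import cv2

ocr = OCR()
crop = cv2.imread('word_crop.png', cv2.IMREAD_GRAYSCALE)  # hypothetical path to a word crop
print(ocr.run(crop))  # prints and returns the decoded text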
Ejemplo n.º 7
0
class Pytorch_model:
    def __init__(self, model_path: str, gpu_id=None):
        '''
        Initialize the PyTorch model.
        :param model_path: path to the checkpoint (bare weights, or weights saved together with the config)
        :param gpu_id: which GPU to run on (int); falls back to CPU when None or CUDA is unavailable
        '''
        self.gpu_id = gpu_id
        checkpoint = torch.load(model_path)

        if self.gpu_id is not None and isinstance(
                self.gpu_id, int) and torch.cuda.is_available():
            self.device = torch.device("cuda:%s" % self.gpu_id)
        else:
            self.device = torch.device("cpu")
        print('device:', self.device)

        config = checkpoint['config']
        self.prediction_type = config['arch']['args']['prediction']['type']
        if self.prediction_type == 'CTC':
            self.converter = CTCLabelConverter(
                config['data_loader']['args']['alphabet'])
        else:
            self.converter = AttnLabelConverter(
                config['data_loader']['args']['alphabet'])
        num_class = len(self.converter.character)

        self.net = get_model(num_class, config)

        self.img_w = config['data_loader']['args']['dataset']['img_w']
        self.img_h = config['data_loader']['args']['dataset']['img_h']
        self.img_channel = config['data_loader']['args']['dataset'][
            'img_channel']

        self.net.load_state_dict(checkpoint['state_dict'])
        self.net.to(self.device)
        self.net.eval()

    def predict(self, img):
        '''
        Run prediction on the given image; both an image path and a numpy array are supported.
        :param img: image path or numpy array
        :return: decoded text
        '''
        assert self.img_channel in [1, 3], 'img_channel must be in [1, 3]'

        if isinstance(img, str):  # read image
            assert os.path.exists(img), 'file does not exist'
            img = cv2.imread(img, 0 if self.img_channel == 1 else 1)

        img = self.pre_processing(img)
        # convert the image from (w, h) to (1, img_channel, h, w)
        img = transforms.ToTensor()(img)
        img = img.unsqueeze_(0)

        img = img.to(self.device)
        with torch.no_grad():
            text = torch.zeros(1, 80, dtype=torch.long, device=self.device)
            preds = self.net(img, text)
            preds = torch.softmax(preds, dim=2).detach().cpu().numpy()
        preds_str = self.converter.decode(preds)
        return preds_str

    def pre_processing(self, img):
        """
        对图片进行处理,先按照高度进行resize,resize之后如果宽度不足指定宽度,就补黑色像素,否则就强行缩放到指定宽度
        :param img_path: 图片
        :return:
        """
        img_h = self.img_h
        img_w = self.img_w
        h, w = img.shape[:2]
        ratio_h = float(img_h) / h
        new_w = int(w * ratio_h)
        img = cv2.resize(img, (new_w, img_h))
        return img
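The docstring above describes a resize-then-pad policy, while pre_processing only performs the height resize. A hedged sketch of the full policy it describes (pad narrow results with black up to img_w, squeeze wider ones down to img_w); the function name and defaults are illustrative, not from the original code:

import cv2
import numpy as np

def resize_then_pad(img, img_h=32, img_w=100):
    """Resize to the target height, then pad with black or squeeze to the target width."""
    h, w = img.shape[:2]
    new_w = int(w * img_h / h)
    img = cv2.resize(img, (min(new_w, img_w), img_h))
    if img.shape[1] < img_w:
        pad = np.zeros((img_h, img_w - img.shape[1], *img.shape[2:]), dtype=img.dtype)
        img = np.concatenate([img, pad], axis=1)
    return img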
class TextExtractor():
    def __init__(self, image_folder, extract_text_file, split):
        self.i_folder = image_folder
        #print(image_folder)
        #print("aaaaaaa test")
        self.extract_text_file = extract_text_file
        self.canvas_size = 1280
        self.mag_ratio = 1.5
        self.show_time = False
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.cuda = torch.cuda.is_available()
        self.net = CRAFT()  #(1st model) model to detect words in images
        if self.cuda:
            self.net.load_state_dict(
                self.copyStateDict(
                    torch.load('CRAFT-pytorch/craft_mlt_25k.pth')))
        else:
            self.net.load_state_dict(
                self.copyStateDict(
                    torch.load('CRAFT-pytorch/craft_mlt_25k.pth',
                               map_location='cpu')))
        if self.cuda:
            self.net = self.net.cuda()
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = False
        self.net.eval()
        self.refine_net = None

        self.text_threshold = 0.7
        self.link_threshold = 0.4
        self.low_text = 0.4
        self.poly = False

        self.result_folder = './' + split + '_' + 'intermediate_result/'

        if not os.path.isdir(self.result_folder):
            os.mkdir(self.result_folder)

        #Parameters for image to text model (2nd model)
        self.parser = argparse.ArgumentParser()
        #Data processing
        self.parser.add_argument('--batch_max_length',
                                 type=int,
                                 default=25,
                                 help='maximum-label-length')
        self.parser.add_argument('--imgH',
                                 type=int,
                                 default=32,
                                 help='the height of the input image')
        self.parser.add_argument('--imgW',
                                 type=int,
                                 default=100,
                                 help='the width of the input image')
        self.parser.add_argument('--rgb',
                                 default=False,
                                 action='store_true',
                                 help='use rgb input')
        self.parser.add_argument(
            '--character',
            type=str,
            default='0123456789abcdefghijklmnopqrstuvwxyz',
            help='character label')
        self.parser.add_argument('--sensitive',
                                 action='store_true',
                                 help='for sensitive character mode')
        self.parser.add_argument(
            '--PAD',
            action='store_true',
            help='whether to keep ratio then pad for image resize')
        #Model Architecture
        self.parser.add_argument('--Transformation',
                                 type=str,
                                 default='TPS',
                                 help='Transformation stage. None|TPS')
        self.parser.add_argument(
            '--FeatureExtraction',
            type=str,
            default='ResNet',
            help='FeatureExtraction stage. VGG|RCNN|ResNet')
        self.parser.add_argument('--SequenceModeling',
                                 type=str,
                                 default='BiLSTM',
                                 help='SequenceModeling stage. None|BiLSTM')
        self.parser.add_argument('--Prediction',
                                 type=str,
                                 default='Attn',
                                 help='Prediction stage. CTC|Attn')
        self.parser.add_argument('--num_fiducial',
                                 type=int,
                                 default=20,
                                 help='number of fiducial points of TPS-STN')
        self.parser.add_argument(
            '--input_channel',
            type=int,
            default=1,
            help='the number of input channel of Feature extractor')
        self.parser.add_argument(
            '--output_channel',
            type=int,
            default=512,
            help='the number of output channel of Feature extractor')
        self.parser.add_argument('--hidden_size',
                                 type=int,
                                 default=256,
                                 help='the size of the LSTM hidden state')
        #self.opt = self.parser.parse_args()
        self.opt, unknown = self.parser.parse_known_args()

        if 'CTC' in self.opt.Prediction:
            self.converter = CTCLabelConverter(self.opt.character)
        else:
            self.converter = AttnLabelConverter(self.opt.character)
        self.opt.num_class = len(self.converter.character)
        #print(opt.rgb)
        if self.opt.rgb:
            self.opt.input_channel = 3
        self.opt.num_gpu = torch.cuda.device_count()
        self.opt.batch_size = 192
        #self.opt.batch_size = 3
        self.opt.workers = 0
        self.model = Model(self.opt)  #image to text model (2nd model)
        self.model = torch.nn.DataParallel(self.model).to(self.device)
        self.model.load_state_dict(
            torch.load(
                'deep-text-recognition-benchmark/TPS-ResNet-BiLSTM-Attn.pth',
                map_location=self.device))
        self.model.eval()

    def copyStateDict(self, state_dict):
        if list(state_dict.keys())[0].startswith("module"):
            start_idx = 1
        else:
            start_idx = 0
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = ".".join(k.split(".")[start_idx:])
            new_state_dict[name] = v
        return new_state_dict

    def test_net(self,
                 net,
                 image,
                 text_threshold,
                 link_threshold,
                 low_text,
                 cuda,
                 poly,
                 refine_net=None):
        t0 = time.time()

        # resize
        img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(
            image,
            self.canvas_size,
            interpolation=cv2.INTER_LINEAR,
            mag_ratio=self.mag_ratio)
        ratio_h = ratio_w = 1 / target_ratio

        # preprocessing
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] to [c, h, w]
        x = Variable(x.unsqueeze(0))  # [c, h, w] to [b, c, h, w]
        if cuda:
            x = x.cuda()

        # forward pass
        with torch.no_grad():
            y, feature = net(x)

        # make score and link map
        score_text = y[0, :, :, 0].cpu().data.numpy()
        score_link = y[0, :, :, 1].cpu().data.numpy()

        # refine link
        if refine_net is not None:
            with torch.no_grad():
                y_refiner = refine_net(y, feature)
            score_link = y_refiner[0, :, :, 0].cpu().data.numpy()

        t0 = time.time() - t0
        t1 = time.time()

        # Post-processing
        boxes, polys = craft_utils.getDetBoxes(score_text, score_link,
                                               text_threshold, link_threshold,
                                               low_text, poly)

        # coordinate adjustment
        boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
        polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
        for k in range(len(polys)):
            if polys[k] is None: polys[k] = boxes[k]

        t1 = time.time() - t1

        # render results (optional)
        render_img = score_text.copy()
        render_img = np.hstack((render_img, score_link))
        ret_score_text = imgproc.cvt2HeatmapImg(render_img)

        if self.show_time:
            print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

        return boxes, polys, ret_score_text

    def extract_text(self):
        l = sorted(os.listdir(self.i_folder))
        img_to_index = {}
        count = 0
        for full_file in l:
            split_file = full_file.split(".")
            filename = split_file[0]
            img_to_index[count] = filename
            #print(count, filename)
            count += 1
            #print(filename)
            file_extension = "." + split_file[1]
            #print(filename, file_extension)
            image = imgproc.loadImage(self.i_folder + full_file)
            bboxes, polys, score_text = self.test_net(
                self.net, image, self.text_threshold, self.link_threshold,
                self.low_text, self.cuda, self.poly, self.refine_net)
            img = cv2.imread(self.i_folder + filename + file_extension)
            rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            points = []
            order = []
            for i in range(0, len(bboxes)):
                sample_bbox = bboxes[i]
                min_point = sample_bbox[0]
                max_point = sample_bbox[2]
                for j, p in enumerate(sample_bbox):
                    if (p[0] <= min_point[0]):
                        min_point = (p[0], min_point[1])
                    if (p[1] <= min_point[1]):
                        min_point = (min_point[0], p[1])
                    if (p[0] >= max_point[0]):
                        max_point = (p[0], max_point[1])
                    if (p[1] >= max_point[1]):
                        max_point = (max_point[0], p[1])
                min_point = (max(min(len(rgb_img[0]), min_point[0]),
                                 0), max(min(len(rgb_img), min_point[1]), 0))
                max_point = (max(min(len(rgb_img[0]), max_point[0]),
                                 0), max(min(len(rgb_img), max_point[1]), 0))
                points.append((min_point, max_point))
                order.append(0)
            num_ordered = 0
            rows_ordered = 0
            points_sorted = []
            ordered_points_index = 0
            order_sorted = []
            while (num_ordered < len(points)):
                #find lowest-y that is unordered
                min_y = len(rgb_img)
                min_y_index = -1
                for i in range(0, len(points)):
                    if (order[i] == 0):
                        if (points[i][0][1] <= min_y):
                            min_y = points[i][0][1]
                            min_y_index = i
                rows_ordered += 1
                order[min_y_index] = rows_ordered
                num_ordered += 1
                points_sorted.append(points[min_y_index])
                order_sorted.append(rows_ordered)
                ordered_points_index = len(points_sorted) - 1

                # Group bboxes that are on the same row
                max_y = points[min_y_index][1][1]
                range_y = max_y - min_y
                for i in range(0, len(points)):
                    if (order[i] == 0):
                        min_y_i = points[i][0][1]
                        max_y_i = points[i][1][1]
                        range_y_i = max_y_i - min_y_i
                        if (max_y_i >= min_y and min_y_i <= max_y):
                            overlap = (min(max_y_i, max_y) -
                                       max(min_y_i, min_y)) / (max(
                                           1, min(range_y, range_y_i)))
                            if (overlap >= 0.30):
                                order[i] = rows_ordered
                                num_ordered += 1
                                min_x_i = points[i][0][0]
                                for j in range(ordered_points_index,
                                               len(points_sorted) + 1):
                                    if (j < len(points_sorted)
                                        ):  #insert before
                                        min_x_j = points_sorted[j][0][0]
                                        if (min_x_i < min_x_j):
                                            points_sorted.insert(j, points[i])
                                            order_sorted.insert(
                                                j, rows_ordered)
                                            break
                                    else:  #insert at the end of array
                                        points_sorted.insert(j, points[i])
                                        order_sorted.insert(j, rows_ordered)
                                        break
            for i in range(0, len(points_sorted)):
                min_point = points_sorted[i][0]
                max_point = points_sorted[i][1]
                mask_file = self.result_folder + filename + "_" + str(
                    order_sorted[i]) + "_" + str(i) + file_extension
                crop_image = rgb_img[int(min_point[1]):int(max_point[1]),
                                     int(min_point[0]):int(max_point[0])]
                #print(filename, min_point, max_point, len(rgb_img), len(rgb_img[0]))
                cv2.imwrite(mask_file, crop_image)
        AlignCollate_demo = AlignCollate(imgH=self.opt.imgH,
                                         imgW=self.opt.imgW,
                                         keep_ratio_with_pad=self.opt.PAD)
        demo_data = RawDataset(root=self.result_folder,
                               opt=self.opt)  # use RawDataset
        demo_loader = torch.utils.data.DataLoader(
            demo_data,
            batch_size=self.opt.batch_size,
            shuffle=False,
            num_workers=int(self.opt.workers),
            collate_fn=AlignCollate_demo,
            pin_memory=True)
        f = open(self.extract_text_file, "w")
        count = -1
        curr_order = 1
        curr_filename = ""
        output_string = ""
        end_line = "[SEP] "
        with torch.no_grad():
            for image_tensors, image_path_list in demo_loader:
                batch_size = image_tensors.size(0)
                image = image_tensors.to(self.device)
                #image = (torch.from_numpy(crop_image).unsqueeze(0)).to(device)
                #print(image_path_list)
                #print(image.size())
                length_for_pred = torch.IntTensor([self.opt.batch_max_length] *
                                                  batch_size).to(self.device)
                text_for_pred = torch.LongTensor(batch_size,
                                                 self.opt.batch_max_length +
                                                 1).fill_(0).to(self.device)
                preds = self.model(image, text_for_pred, is_train=False)
                _, preds_index = preds.max(2)
                preds_str = self.converter.decode(preds_index, length_for_pred)
                for path, p in zip(image_path_list, preds_str):
                    #print(path)
                    if 'Attn' in self.opt.Prediction:
                        pred_EOS = p.find('[s]')
                        p = p[:
                              pred_EOS]  # prune after "end of sentence" token ([s])
                    path_info = path[len(self.result_folder):].split(
                        ".")[0].split(
                            "_"
                        )  #ASSUMES FILE EXTENSION OF SIZE 4 (.PNG, .JPG, ETC)
                    #print(curr_filename)
                    #print(path_info[0])
                    #print("PATHINFO: ",path_info[0])
                    if (not (curr_filename == path_info[0])):
                        if (not (curr_filename == "")):
                            f.write(str(count) + "\n")
                            f.write(curr_filename + "\n")
                            f.write(output_string + "\n\n")
                        count += 1
                        curr_filename = img_to_index[count]  #path_info[0]
                        #print("CURRFILE: ", curr_filename)
                        while (not (curr_filename == path_info[0])):
                            f.write(str(count) + "\n")
                            f.write(curr_filename + "\n")
                            f.write("\n\n")
                            count += 1
                            curr_filename = img_to_index[count]  #path_info[0]
                            #print("CURRFILE: ", curr_filename)
                        output_string = ""
                        curr_order = 1
                    if (int(path_info[1]) > curr_order):
                        curr_order += 1
                        output_string += end_line
                    output_string += p + " "
            f.write(str(count) + "\n")
            f.write(curr_filename + "\n")
            f.write(output_string + "\n\n")
        f.close()

        #Go through each image in the i_folder and crop out text

        #generate text and write to text file

    def get_item(self, index):
        f = open(self.extract_text_file, "r")
        Lines = f.readlines()
        return (Lines[4 * index + 2][:-1])
        # read text file


#TEST
#t_e = TextExtractor("data/mmimdb-256/dataset-resized-256max/dev_n/images/","text_extract_output.txt")
#t_e.extract_text()
#text = t_e.get_item(1)
#print(text)
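The extract_text method above orders CRAFT boxes into reading order by grouping boxes whose vertical overlap is at least 30% of the shorter box's height into the same row, then sorting each row left to right. A standalone, simplified sketch of that grouping heuristic (not the original code; box format assumed to match the (min_point, max_point) pairs built above):

def group_boxes_into_rows(boxes, overlap_thresh=0.30):
    """boxes: list of ((min_x, min_y), (max_x, max_y)). Returns rows top-to-bottom, each sorted left-to-right."""
    rows = []
    for box in sorted(boxes, key=lambda b: b[0][1]):  # scan boxes from top to bottom
        min_y, max_y = box[0][1], box[1][1]
        placed = False
        for row in rows:
            row_min_y = min(b[0][1] for b in row)
            row_max_y = max(b[1][1] for b in row)
            overlap = min(max_y, row_max_y) - max(min_y, row_min_y)
            shorter = max(1, min(max_y - min_y, row_max_y - row_min_y))
            if overlap / shorter >= overlap_thresh:
                row.append(box)
                placed = True
                break
        if not placed:
            rows.append([box])
    return [sorted(row, key=lambda b: b[0][0]) for row in rows]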
Ejemplo n.º 9
0
def demo(opt):
    """ model configuration """
    converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    # print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
    #       opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
    #       opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    # predict
    model.eval()

    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            all_pred_strs = []
            all_confidence_scores = []
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                             1).fill_(0).to(device)

            predss = model(image, text_for_pred, is_train=False)[0]

            for i, preds in enumerate(predss):
                confidence_score_list = []
                pred_str_list = []

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

                preds_prob = F.softmax(preds, dim=2)
                preds_max_prob, _ = preds_prob.max(dim=2)
                for pred, pred_max_prob in zip(preds_str, preds_max_prob):
                    pred_EOS = pred.find('[s]')
                    pred = pred[:
                                pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_str_list.append(pred)
                    pred_max_prob = pred_max_prob[:pred_EOS]

                    # calculate confidence score (= multiply of pred_max_prob)
                    try:
                        confidence_score = pred_max_prob.cumprod(
                            dim=0)[-1].cpu().numpy()
                    except:
                        confidence_score = 0  # for empty pred case, when prune after "end of sentence" token ([s])
                    confidence_score_list.append(confidence_score)

                all_pred_strs.append(pred_str_list)
                all_confidence_scores.append(confidence_score_list)

            all_confidence_scores = np.array(all_confidence_scores)
            all_pred_strs = np.array(all_pred_strs)

            best_pred_index = np.argmax(all_confidence_scores, axis=0)
            best_pred_index = np.expand_dims(best_pred_index, axis=0)

            # Get max prediction per image across blocks
            all_pred_strs = np.take_along_axis(all_pred_strs,
                                               best_pred_index,
                                               axis=0)[0]
            all_confidence_scores = np.take_along_axis(all_confidence_scores,
                                                       best_pred_index,
                                                       axis=0)[0]

            log = open(f'./log_demo_result.txt', 'w')
            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

            print(f'{dashed_line}\n{head}\n{dashed_line}')
            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')
            for img_name, pred, confidence_score in zip(
                    image_path_list, all_pred_strs, all_confidence_scores):
                print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                log.write(
                    f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')

            log.close()
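The confidence score used above is the product of the per-step maximum probabilities along the greedy decoding path, obtained with cumprod. A tiny illustration with made-up values:

import torch

pred_max_prob = torch.tensor([0.99, 0.95, 0.90])     # hypothetical per-character max probabilities
confidence_score = pred_max_prob.cumprod(dim=0)[-1]  # 0.99 * 0.95 * 0.90 ≈ 0.846
print(confidence_score.item())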
Ejemplo n.º 10
0
class ModelWrapper(Model):
    def __init__(self, opt):
        if 'CTC' in opt.Prediction:
            self.converter = CTCLabelConverter(opt.character)
        else:
            self.converter = AttnLabelConverter(opt.character)
        opt.num_class = len(self.converter.character)
        if opt.rgb:
            opt.input_channel = 3
        super().__init__(opt)
        # Wrap in DataParallel only to load the 'module.'-prefixed checkpoint into this model;
        # rebinding `self` would not replace the object seen by callers, so keep the wrapper separate.
        wrapped = torch.nn.DataParallel(self).to(device)
        print('loading pretrained model from %s' % opt.saved_model)
        wrapped.load_state_dict(torch.load(opt.saved_model, map_location=device))
        self.opt = opt

    def predict(self, img):
        self.eval()
        batch_size = 1
        with torch.no_grad():
            AlignCollate_demo = AlignCollate(imgH=self.opt.imgH,
                                             imgW=self.opt.imgW,
                                             keep_ratio_with_pad=self.opt.PAD)
            transform_PIL = transforms.ToPILImage()

            image = [transform_PIL(img)]

            image = AlignCollate_demo((image, ""))[0]
            length_for_pred = torch.IntTensor([self.opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(
                batch_size, self.opt.batch_max_length + 1).fill_(0).to(device)
            print(image.shape)
            cv2.imshow("", tensor2im(image[0]))

            if 'CTC' in self.opt.Prediction:
                preds = self(image, text_for_pred).log_softmax(2)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = self.converter.decode(preds_index.data,
                                                  preds_size.data)

            else:
                preds = self(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = self.converter.decode(preds_index, length_for_pred)
            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            pred_max_prob = preds_max_prob[0]
            pred = preds_str[0]
            if 'Attn' in self.opt.Prediction:
                pred_EOS = pred.find('[s]')
                pred = pred[:
                            pred_EOS]  # prune after "end of sentence" token ([s])
                pred_max_prob = pred_max_prob[:pred_EOS]
            confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            return pred, confidence_score
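Several of these examples prune Attn predictions at the '[s]' end-of-sentence token inline. A small helper, written here as a sketch (it also guards the case where '[s]' is absent, which the inline slices do not):

def prune_attn_prediction(pred: str) -> str:
    """Cut an attention-decoder string at its '[s]' end-of-sentence token."""
    eos = pred.find('[s]')
    return pred if eos == -1 else pred[:eos]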
class TextReader(object):

    opts = SimpleNamespace()
    trmodel = None
    device = None
    convertor = None

    def __init__(self, args):

        path = os.path.abspath(__file__)
        dir_path = os.path.dirname(path)

        self.opts.workers = args.get("workers",
                                     5)  # number of data loading workers
        ## data processing args
        self.opts.batch_size = args.get('batch_size',
                                        50)  #input batch size
        self.opts.batch_max_length = args.get('batch_max_length',
                                              25)  #maximum-label-length
        self.opts.imgH = args.get('imgH', 32)  # the height of the input image
        self.opts.imgW = args.get('imgW', 100)  # #the width of the input image
        self.opts.rgb = args.get('rgb', False)  # use rgb input
        self.opts.character = args.get(
            'character',
            '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~'
        )
        #self.opts.character = args.get('character','0123456789abcdefghijklmnopqrstuvwxyz') #character label
        self.opts.sensitive = args.get('sensitive',
                                       True)  #for sensitive character mode
        self.opts.PAD = args.get(
            'PAD', True)  #whether to keep ratio then pad for image resize
        ## Model architecture
        self.opts.Transformation = args.get(
            'Transformation', 'TPS')  #Transformation stage. None|TPS
        self.opts.FeatureExtraction = args.get(
            'FeatureExtraction',
            'ResNet')  #FeatureExtraction stage. VGG|RCNN|ResNet
        self.opts.SequenceModeling = args.get(
            'SequenceModeling', 'BiLSTM')  #SequenceModeling stage. None|BiLSTM
        self.opts.Prediction = args.get('Prediction',
                                        'Attn')  #Prediction stage. CTC|Attn
        self.opts.num_fiducial = args.get(
            'num_fiducial', 20)  #number of fiducial points of TPS-STN
        self.opts.input_channel = args.get(
            'input_channel',
            1)  #the number of input channel of Feature extractor
        self.opts.output_channel = args.get(
            'output_channel',
            512)  #the number of output channel of Feature extractor
        self.opts.hidden_size = args.get(
            'hidden_size', 256)  #the size of the LSTM hidden state
        self.opts.num_gpu = torch.cuda.device_count(
        ) if torch.cuda.is_available() else 0
        self.opts.saved_model = args.get(
            'saved_model',
            dir_path + "/models/TPS-ResNet-BiLSTM-Attn-case-sensitive.pth")

        if self.opts.sensitive:
            self.opts.character = string.printable[:
                                                   -6]  # same with ASTER setting (use 94 char).

        if 'CTC' in self.opts.Prediction:
            self.converter = CTCLabelConverter(self.opts.character)
        else:
            self.converter = AttnLabelConverter(self.opts.character)
        self.opts.num_class = len(self.converter.character)

        if self.opts.rgb:
            self.opts.input_channel = 3

        self.pre_load_model()

    #
    # private function to preload the torch model
    #
    def pre_load_model(self):

        opt = self.opts
        print("preloading the model with opts " + str(opt))

        self.trmodel = Model(opt)

        self.trmodel = torch.nn.DataParallel(self.trmodel)
        if torch.cuda.is_available():
            self.trmodel = self.trmodel.cuda()
            self.device = torch.device('cuda:0')
        else:
            self.device = torch.device('cpu')

        # load model
        print('loading pretrained model from %s' % self.opts.saved_model)
        if torch.cuda.is_available():
            self.trmodel.load_state_dict(torch.load(self.opts.saved_model))
        else:
            self.trmodel.load_state_dict(
                torch.load(opt.saved_model, map_location='cpu'))

        self.trmodel.eval()

    #
    #
    #
    def predictAllImagesInFolder(self, src_path):

        opt = self.opts
        AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                         imgW=opt.imgW,
                                         keep_ratio_with_pad=opt.PAD)
        demo_data = RawDataset(root=src_path, opt=opt)  # use RawDataset
        demo_loader = torch.utils.data.DataLoader(
            demo_data,
            batch_size=opt.batch_size,
            shuffle=False,
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_demo,
            pin_memory=torch.cuda.is_available())

        results = []
        for image_tensors, image_path_list in demo_loader:

            preds_str = self.predict(image_tensors)

            for img_name, pred in zip(image_path_list, preds_str):
                if 'Attn' in opt.Prediction:
                    pred = pred[:pred.find(
                        '[s]')]  # prune after "end of sentence" token ([s])
                results.append(f'{os.path.basename(img_name)},{pred}')

        return results

    ##
    ##
    ##
    def predict(self, image_tensors):

        print("############# About to predict for next batch ****")
        batch_size = image_tensors.size(0)
        with torch.no_grad():
            image = image_tensors.to(self.device)
            length_for_pred = torch.IntTensor([self.opts.batch_max_length] *
                                              batch_size).to(self.device)
            text_for_pred = torch.LongTensor(
                batch_size,
                self.opts.batch_max_length + 1).fill_(0).to(self.device)

            if 'CTC' in self.opts.Prediction:
                preds = self.trmodel(image, text_for_pred).log_softmax(2)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = self.converter.decode(preds_index.data,
                                                  preds_size.data)

            else:
                preds = self.trmodel(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = self.converter.decode(preds_index, length_for_pred)

        # print('-' * 80)
        # print('image_path\tpredicted_labels')
        # print('-' * 80)
        # pred = ""
        # for pred in preds_str:
        # if 'Attn' in self.opts.Prediction:
        # pred = pred[:pred.find('[s]')]  # prune after "end of sentence" token ([s])
        # print("predictions = "+pred)

        return preds_str
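A hypothetical usage of the TextReader wrapper above; the option values and folder path are made up:

reader = TextReader({'batch_size': 16, 'batch_max_length': 25})
for line in reader.predictAllImagesInFolder('./word_crops/'):  # hypothetical folder of word crops
    print(line)  # "<image file name>,<predicted text>"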
Ejemplo n.º 12
0
class TextRecongtion():
    def __init__(self):
        self.opt = Recognition_Option()
        # self.opt = parser.parse_args()
        if self.opt.sensitive:
            self.opt.character = string.printable[:
                                                  -6]  # same with ASTER setting (use 94 char).
        cudnn.benchmark = False
        cudnn.deterministic = True
        self.opt.num_gpu = torch.cuda.device_count()

    def load_net(self):
        """ model configuration """
        if 'CTC' in self.opt.Prediction:
            self.converter = CTCLabelConverter(self.opt.character)
        else:
            self.converter = AttnLabelConverter(self.opt.character)
        self.opt.num_class = len(self.converter.character)

        if self.opt.rgb:
            self.opt.input_channel = 3
        print("input channle :{}".format(self.opt.input_channel))
        self.model = Model(self.opt)
        print('model input parameters', self.opt.imgH, self.opt.imgW,
              self.opt.num_fiducial, self.opt.input_channel,
              self.opt.output_channel, self.opt.hidden_size,
              self.opt.num_class, self.opt.batch_max_length,
              self.opt.Transformation, self.opt.FeatureExtraction,
              self.opt.SequenceModeling, self.opt.Prediction)
        self.model = torch.nn.DataParallel(self.model).to(device)

        # load model
        print('loading pretrained model from %s' % self.opt.saved_model)
        state_dict = torch.load(self.opt.saved_model, map_location=device)
        self.model.load_state_dict(state_dict)

        # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
        self.AlignCollate_demo = AlignCollate(imgH=self.opt.imgH,
                                              imgW=self.opt.imgW,
                                              keep_ratio_with_pad=self.opt.PAD)

    def predict(self, image_list):
        if len(image_list) <= 0:
            return [""]
        demo_data = RawDataset(image_list, opt=self.opt)  # use RawDataset
        demo_loader = torch.utils.data.DataLoader(
            demo_data,
            batch_size=self.opt.batch_size,
            shuffle=False,
            num_workers=int(self.opt.workers),
            collate_fn=self.AlignCollate_demo,
            pin_memory=True)

        # predict
        ret = []
        self.model.eval()
        with torch.no_grad():
            for image_tensors, image_path_list in demo_loader:
                batch_size = image_tensors.size(0)
                image = image_tensors.to(device)
                # For max length prediction
                length_for_pred = torch.IntTensor([self.opt.batch_max_length] *
                                                  batch_size).to(device)
                text_for_pred = torch.LongTensor(batch_size,
                                                 self.opt.batch_max_length +
                                                 1).fill_(0).to(device)

                if 'CTC' in self.opt.Prediction:
                    preds = self.model(image, text_for_pred)

                    # Select max probability (greedy decoding) then decode index to character
                    preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                    _, preds_index = preds.max(2)
                    preds_index = preds_index.view(-1)
                    preds_str = self.converter.decode(preds_index.data,
                                                      preds_size.data)

                else:
                    preds = self.model(image, text_for_pred, is_train=False)

                    # select max probability (greedy decoding) then decode index to character
                    _, preds_index = preds.max(2)
                    preds_str = self.converter.decode(preds_index,
                                                      length_for_pred)

                # log = open(f'./log_demo_result.txt', 'a')
                # dashed_line = '-' * 80
                # head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

                # print(f'{dashed_line}\n{head}\n{dashed_line}')
                # log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

                preds_prob = F.softmax(preds, dim=2)
                preds_max_prob, _ = preds_prob.max(dim=2)
                for img_name, pred, pred_max_prob in zip(
                        image_path_list, preds_str, preds_max_prob):
                    if 'Attn' in self.opt.Prediction:
                        pred_EOS = pred.find('[s]')
                        pred = pred[:
                                    pred_EOS]  # prune after "end of sentence" token ([s])
                        pred_max_prob = pred_max_prob[:pred_EOS]

                    # calculate confidence score (= multiply of pred_max_prob)
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                    if confidence_score <= 0.4:
                        pred = "None"
                    ret.append(pred)
                    # ret.append(preds_str)
                    print(
                        f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
        return ret
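A hypothetical usage of the recognizer above; the crop paths are made up, and Recognition_Option is assumed to carry the defaults echoed by load_net:

recognizer = TextRecongtion()
recognizer.load_net()
print(recognizer.predict(['crop_0.png', 'crop_1.png']))  # predictions with confidence <= 0.4 come back as "None"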
Ejemplo n.º 13
0
def test(opt):
    lib.print_model_settings(locals().copy())

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
        text_len = opt.batch_max_length + 2
    else:
        converter = CTCLabelConverter(opt.character)
        text_len = opt.batch_max_length

    opt.classes = converter.character
    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)

    valid_dataset = LmdbDataset(root=opt.test_data, opt=opt)
    test_data_sampler = data_sampler(valid_dataset,
                                     shuffle=False,
                                     distributed=False)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=
        False,  # 'True' to check training progress with validation function.
        sampler=test_data_sampler,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True,
        drop_last=False)

    print('-' * 80)

    opt.num_class = len(converter.character)

    ocrModel = ModelV1(opt).to(device)

    ## Loading pre-trained files
    print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
    checkpoint = torch.load(opt.saved_ocr_model,
                            map_location=lambda storage, loc: storage)
    ocrModel.load_state_dict(checkpoint)

    evalCntr = 0
    fCntr = 0

    c1_s1_input_correct = 0.0
    c1_s1_input_ed_correct = 0.0
    # pdb.set_trace()

    for vCntr, (image_input_tensors, labels_gt) in enumerate(valid_loader):
        print(vCntr)

        image_input_tensors = image_input_tensors.to(device)
        text_gt, length_gt = converter.encode(
            labels_gt, batch_max_length=opt.batch_max_length)

        with torch.no_grad():
            currBatchSize = image_input_tensors.shape[0]
            # text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device)
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              currBatchSize).to(device)
            #Run OCR prediction
            if 'CTC' in opt.Prediction:
                preds = ocrModel(image_input_tensors, text_gt, is_train=False)
                preds_size = torch.IntTensor([preds.size(1)] *
                                             image_input_tensors.shape[0])
                _, preds_index = preds.max(2)
                preds_str_gt_1 = converter.decode(preds_index.data,
                                                  preds_size.data)

            else:
                preds = ocrModel(
                    image_input_tensors, text_gt[:, :-1],
                    is_train=False)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_gt_1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_gt_1):
                    pred_EOS = pred.find('[s]')
                    preds_str_gt_1[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

        for trImgCntr in range(image_input_tensors.shape[0]):
            #ocr accuracy
            # for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            c1_s1_input_gt = labels_gt[trImgCntr]
            c1_s1_input_ocr = preds_str_gt_1[trImgCntr]

            if c1_s1_input_gt == c1_s1_input_ocr:
                c1_s1_input_correct += 1

            # ICDAR2019 Normalized Edit Distance
            if len(c1_s1_input_gt) == 0 or len(c1_s1_input_ocr) == 0:
                c1_s1_input_ed_correct += 0
            elif len(c1_s1_input_gt) > len(c1_s1_input_ocr):
                c1_s1_input_ed_correct += 1 - edit_distance(
                    c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_gt)
            else:
                c1_s1_input_ed_correct += 1 - edit_distance(
                    c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_ocr)

            evalCntr += 1

            fCntr += 1

    avg_c1_s1_input_wer = c1_s1_input_correct / float(evalCntr)
    avg_c1_s1_input_cer = c1_s1_input_ed_correct / float(evalCntr)

    # if not(opt.realVaData):
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_test.txt'),
              'a') as log:
        # training loss and validation loss

        loss_log = f'Test Input Word Acc: {avg_c1_s1_input_wer:0.5f}, Test Input Char Acc: {avg_c1_s1_input_cer:0.5f}'

        print(loss_log)
        log.write(loss_log + "\n")
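
The per-sample bookkeeping above (exact word match plus ICDAR2019 normalized edit distance) can be isolated into a small helper; a sketch, assuming edit_distance is a plain Levenshtein distance such as nltk.metrics.distance.edit_distance, which is what this code appears to rely on.

from nltk.metrics.distance import edit_distance

def word_and_ned_score(gt, ocr):
    # exact word-accuracy contribution (1 or 0)
    word_correct = 1.0 if gt == ocr else 0.0
    # ICDAR2019 normalized edit distance: normalize by the longer string
    if len(gt) == 0 or len(ocr) == 0:
        ned = 0.0
    elif len(gt) > len(ocr):
        ned = 1 - edit_distance(ocr, gt) / len(gt)
    else:
        ned = 1 - edit_distance(ocr, gt) / len(ocr)
    return word_correct, ned
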
Example #14
def test(opt):
    lib.print_model_settings(locals().copy())

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
        text_len = opt.batch_max_length+2
    else:
        converter = CTCLabelConverter(opt.character)
        text_len = opt.batch_max_length

    opt.classes = converter.character
    
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    # AlignCollate_valid = AlignPairImgCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    AlignCollate_valid = AlignPairImgCollate_Test(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    
    valid_dataset = LmdbTestStyleContentDataset(root=opt.test_data, opt=opt, dataMode=opt.realVaData)
    test_data_sampler = data_sampler(valid_dataset, shuffle=False, distributed=False)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size, 
        shuffle=False,  # 'True' to check training progress with validation function.
        sampler=test_data_sampler,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=False)
    
    print('-' * 80)

    AlignCollate_text = AlignSynthTextCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    text_dataset = text_gen_synth(opt)
    text_data_sampler = data_sampler(text_dataset, shuffle=True, distributed=False)
    text_loader = torch.utils.data.DataLoader(
        text_dataset, batch_size=opt.batch_size,
        shuffle=False,
        sampler=text_data_sampler,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_text,
        drop_last=True)
    opt.num_class = len(converter.character)

 
    text_loader = sample_data(text_loader, text_data_sampler, False)

    c_code_size = opt.latent
    if opt.cEncode == 'mlp':
        cEncoder = GlobalContentEncoder(opt.num_class, text_len, opt.char_embed_size, c_code_size).to(device)
    elif opt.cEncode == 'cnn':
        if opt.contentNorm == 'in':
            cEncoder = ResNet_StyleExtractor_WIN(1, opt.latent).to(device)
        else:
            cEncoder = ResNet_StyleExtractor(1, opt.latent).to(device)
    if opt.styleNorm == 'in':
        styleModel = ResNet_StyleExtractor_WIN(opt.input_channel, opt.latent).to(device)
    else:
        styleModel = ResNet_StyleExtractor(opt.input_channel, opt.latent).to(device)
    
    ocrModel = ModelV1(opt).to(device)

    g_ema = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, content_dim=c_code_size, channel_multiplier=opt.channel_multiplier).to(device)
    g_ema.eval()
    
    bestModelError=1e5

    ## Loading pre-trained files
    print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
    checkpoint = torch.load(opt.saved_ocr_model, map_location=lambda storage, loc: storage)
    ocrModel.load_state_dict(checkpoint)

    print(f'loading pretrained synth model from {opt.saved_synth_model}')
    checkpoint = torch.load(opt.saved_synth_model, map_location=lambda storage, loc: storage)
    
    cEncoder.load_state_dict(checkpoint['cEncoder'])
    styleModel.load_state_dict(checkpoint['styleModel'])
    g_ema.load_state_dict(checkpoint['g_ema'])
     
    
    iCntr=0
    evalCntr=0
    fCntr=0
    
    valMSE=0.0
    valSSIM=0.0
    valPSNR=0.0
    c1_s1_input_correct=0.0
    c2_s1_gen_correct=0.0
    c1_s1_input_ed_correct=0.0
    c2_s1_gen_ed_correct=0.0
    
    
    ims, txts = [], []
    
    for vCntr, (image_input_tensors, image_output_tensors, labels_gt, labels_z_c, labelSynthImg, synth_z_c, input_1_shape, input_2_shape) in enumerate(valid_loader):
        print(vCntr)

        if opt.debugFlag and vCntr >10:
            break  
        
        image_input_tensors = image_input_tensors.to(device)
        image_output_tensors = image_output_tensors.to(device)

        if opt.realVaData and opt.outPairFile=="":
            # pdb.set_trace()
            labels_z_c, synth_z_c = next(text_loader)
        
        labelSynthImg = labelSynthImg.to(device)
        synth_z_c = synth_z_c.to(device)
        synth_z_c = synth_z_c[:labelSynthImg.shape[0]]
        
        text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length)
        text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length)
        
        # print(labels_z_c)

        cEncoder.eval()
        styleModel.eval()
        g_ema.eval()

        with torch.no_grad():
            if opt.cEncode == 'mlp':    
                z_c_code = cEncoder(text_z_c)
                z_gt_code = cEncoder(text_gt)
            elif opt.cEncode == 'cnn':    
                z_c_code = cEncoder(synth_z_c)
                z_gt_code = cEncoder(labelSynthImg)
                
            style = styleModel(image_input_tensors)
            
            if opt.noiseConcat or opt.zAlone:
                style = mixing_noise(opt.batch_size, opt.latent, opt.mixing, device, style)
            else:
                style = [style]
            
            fake_img_c1_s1, _ = g_ema(style, z_gt_code, input_is_latent=opt.input_latent)
            fake_img_c2_s1, _ = g_ema(style, z_c_code, input_is_latent=opt.input_latent)

            currBatchSize = fake_img_c1_s1.shape[0]
            # text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device)
            length_for_pred = torch.IntTensor([opt.batch_max_length] * currBatchSize).to(device)
            #Run OCR prediction
            if 'CTC' in opt.Prediction:
                preds = ocrModel(fake_img_c1_s1, text_gt, is_train=False, inAct = opt.taskActivation)
                preds_size = torch.IntTensor([preds.size(1)] * currBatchSize)
                _, preds_index = preds.max(2)
                preds_str_fake_img_c1_s1 = converter.decode(preds_index.data, preds_size.data)

                preds = ocrModel(fake_img_c2_s1, text_z_c, is_train=False, inAct = opt.taskActivation)
                preds_size = torch.IntTensor([preds.size(1)] * currBatchSize)
                _, preds_index = preds.max(2)
                preds_str_fake_img_c2_s1 = converter.decode(preds_index.data, preds_size.data)

                preds = ocrModel(image_input_tensors, text_gt, is_train=False)
                preds_size = torch.IntTensor([preds.size(1)] * image_input_tensors.shape[0])
                _, preds_index = preds.max(2)
                preds_str_gt_1 = converter.decode(preds_index.data, preds_size.data)

                preds = ocrModel(image_output_tensors, text_z_c, is_train=False)
                preds_size = torch.IntTensor([preds.size(1)] * image_output_tensors.shape[0])
                _, preds_index = preds.max(2)
                preds_str_gt_2 = converter.decode(preds_index.data, preds_size.data)

            else:
                
                preds = ocrModel(fake_img_c1_s1, text_gt[:, :-1], is_train=False, inAct = opt.taskActivation)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_fake_img_c1_s1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_fake_img_c1_s1):
                    pred_EOS = pred.find('[s]')
                    preds_str_fake_img_c1_s1[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                
                preds = ocrModel(fake_img_c2_s1, text_z_c[:, :-1], is_train=False, inAct = opt.taskActivation)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_fake_img_c2_s1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_fake_img_c2_s1):
                    pred_EOS = pred.find('[s]')
                    preds_str_fake_img_c2_s1[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                

                preds = ocrModel(image_input_tensors, text_gt[:, :-1], is_train=False)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_gt_1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_gt_1):
                    pred_EOS = pred.find('[s]')
                    preds_str_gt_1[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

                preds = ocrModel(image_output_tensors, text_z_c[:, :-1], is_train=False)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_gt_2 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_gt_2):
                    pred_EOS = pred.find('[s]')
                    preds_str_gt_2[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

        pathPrefix = os.path.join(opt.valDir, opt.exp_iter)
        os.makedirs(os.path.join(pathPrefix), exist_ok=True)
        
        for trImgCntr in range(image_output_tensors.shape[0]):
            
            if opt.outPairFile!="":
                labelId = 'label-' + valid_loader.dataset.pairId[fCntr] + '-' + str(fCntr)
            else:
                labelId = 'label-%09d' % valid_loader.dataset.filtered_index_list[fCntr]
            #evaluations
            valRange = (-1,+1)
            # pdb.set_trace()
            # inpTensor = skimage.img_as_ubyte(resize(tensor2im(image_input_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0])))
            # gtTensor = skimage.img_as_ubyte(resize(tensor2im(image_output_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0])))
            # predTensor = skimage.img_as_ubyte(resize(tensor2im(fake_img_c2_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0])))
            
            # inpTensor = resize(tensor2im(image_input_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]), anti_aliasing=True)
            # gtTensor = resize(tensor2im(image_output_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]), anti_aliasing=True)
            # predTensor = resize(tensor2im(fake_img_c2_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]), anti_aliasing=True)

            inpTensor = F.interpolate(image_input_tensors[trImgCntr].unsqueeze(0),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]))
            gtTensor = F.interpolate(image_output_tensors[trImgCntr].unsqueeze(0),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]))
            predTensor = F.interpolate(fake_img_c2_s1[trImgCntr].unsqueeze(0),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]))
            predGTTensor = F.interpolate(fake_img_c1_s1[trImgCntr].unsqueeze(0),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]))

            # inpTensor = cv2.resize(tensor2im(image_input_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr]))
            # gtTensor = cv2.resize(tensor2im(image_output_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr]))
            # predTensor = cv2.resize(tensor2im(fake_img_c2_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr]))
            # predGTTensor = cv2.resize(tensor2im(fake_img_c1_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr]))

            # inpTensor = cv2.medianBlur(cv2.resize(tensor2im(image_input_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr])),5)
            # gtTensor = cv2.medianBlur(cv2.resize(tensor2im(image_output_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr])),5)
            # predTensor = cv2.medianBlur(cv2.resize(tensor2im(fake_img_c2_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr])),5)
            # pdb.set_trace()

            if not(opt.realVaData):
                evalMSE = mean_squared_error(tensor2im(gtTensor.squeeze())/255, tensor2im(predTensor.squeeze())/255)
                # evalMSE = mean_squared_error(gtTensor/255, predTensor/255)
                evalSSIM = structural_similarity(tensor2im(gtTensor.squeeze())/255, tensor2im(predTensor.squeeze())/255, multichannel=True)
                # evalSSIM = structural_similarity(gtTensor/255, predTensor/255, multichannel=True)
                evalPSNR = peak_signal_noise_ratio(tensor2im(gtTensor.squeeze())/255, tensor2im(predTensor.squeeze())/255)
                # evalPSNR = peak_signal_noise_ratio(gtTensor/255, predTensor/255)
                # print(evalMSE,evalSSIM,evalPSNR)

                valMSE+=evalMSE
                valSSIM+=evalSSIM
                valPSNR+=evalPSNR

            #ocr accuracy
            # for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            c1_s1_input_gt = labels_gt[trImgCntr]
            c1_s1_input_ocr = preds_str_gt_1[trImgCntr]
            c2_s1_gen_gt = labels_z_c[trImgCntr]
            c2_s1_gen_ocr = preds_str_fake_img_c2_s1[trImgCntr]

            if c1_s1_input_gt == c1_s1_input_ocr:
                c1_s1_input_correct += 1
            if c2_s1_gen_gt == c2_s1_gen_ocr:
                c2_s1_gen_correct += 1

            # ICDAR2019 Normalized Edit Distance
            if len(c1_s1_input_gt) == 0 or len(c1_s1_input_ocr) == 0:
                c1_s1_input_ed_correct += 0
            elif len(c1_s1_input_gt) > len(c1_s1_input_ocr):
                c1_s1_input_ed_correct += 1 - edit_distance(c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_gt)
            else:
                c1_s1_input_ed_correct += 1 - edit_distance(c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_ocr)
            
            if len(c2_s1_gen_gt) == 0 or len(c2_s1_gen_ocr) == 0:
                c2_s1_gen_ed_correct += 0
            elif len(c2_s1_gen_gt) > len(c2_s1_gen_ocr):
                c2_s1_gen_ed_correct += 1 - edit_distance(c2_s1_gen_ocr, c2_s1_gen_gt) / len(c2_s1_gen_gt)
            else:
                c2_s1_gen_ed_correct += 1 - edit_distance(c2_s1_gen_ocr, c2_s1_gen_gt) / len(c2_s1_gen_ocr)
            
            evalCntr+=1
            
            #save generated images
            if opt.visFlag and iCntr>500:
                pass
            else:
                try:
                    if iCntr == 0:
                        # update website
                        webpage = html.HTML(pathPrefix, 'Experiment name = %s' % 'Test')
                        webpage.add_header('Testing iteration')
                        
                    iCntr += 1
                    img_path_c1_s1 = os.path.join(labelId+'_pred_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_fake_img_c1_s1[trImgCntr]+'.png')
                    img_path_gt_1 = os.path.join(labelId+'_gt_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_gt_1[trImgCntr]+'.png')
                    img_path_gt_2 = os.path.join(labelId+'_gt_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_gt_2[trImgCntr]+'.png')
                    img_path_c2_s1 = os.path.join(labelId+'_pred_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_fake_img_c2_s1[trImgCntr]+'.png')
                    
                    ims.append([img_path_gt_1, img_path_c1_s1, img_path_gt_2, img_path_c2_s1])

                    content_c1_s1 = 'PSTYLE-1 '+'Text-1:' + labels_gt[trImgCntr]+' OCR:' + preds_str_fake_img_c1_s1[trImgCntr]
                    content_gt_1 = 'OSTYLE-1 '+'GT:' + labels_gt[trImgCntr]+' OCR:' + preds_str_gt_1[trImgCntr]
                    content_gt_2 = 'OSTYLE-1 '+'GT:' + labels_z_c[trImgCntr]+' OCR:'+preds_str_gt_2[trImgCntr]
                    content_c2_s1 = 'PSTYLE-1 '+'Text-2:' + labels_z_c[trImgCntr]+' OCR:'+preds_str_fake_img_c2_s1[trImgCntr]
                    
                    txts.append([content_gt_1, content_c1_s1, content_gt_2, content_c2_s1])
                    
                    utils.save_image(predGTTensor,os.path.join(pathPrefix,labelId+'_pred_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_fake_img_c1_s1[trImgCntr]+'.png'),nrow=1,normalize=True,range=(-1, 1))
                    # cv2.imwrite(os.path.join(pathPrefix,labelId+'_pred_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_fake_img_c1_s1[trImgCntr]+'.png'), predGTTensor)
                    
                    # pdb.set_trace()
                    if not opt.zAlone:
                        utils.save_image(inpTensor,os.path.join(pathPrefix,labelId+'_gt_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_gt_1[trImgCntr]+'.png'),nrow=1,normalize=True,range=(-1, 1))
                        utils.save_image(gtTensor,os.path.join(pathPrefix,labelId+'_gt_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_gt_2[trImgCntr]+'.png'),nrow=1,normalize=True,range=(-1, 1))
                        utils.save_image(predTensor,os.path.join(pathPrefix,labelId+'_pred_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_fake_img_c2_s1[trImgCntr]+'.png'),nrow=1,normalize=True,range=(-1, 1)) 
                        
                        # imsave(os.path.join(pathPrefix,labelId+'_gt_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_gt_1[trImgCntr]+'.png'), inpTensor)
                        # imsave(os.path.join(pathPrefix,labelId+'_gt_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_gt_2[trImgCntr]+'.png'), gtTensor)
                        # imsave(os.path.join(pathPrefix,labelId+'_pred_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_fake_img_c2_s1[trImgCntr]+'.png'), predTensor) 

                        # cv2.imwrite(os.path.join(pathPrefix,labelId+'_gt_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_gt_1[trImgCntr]+'.png'), inpTensor)
                        # cv2.imwrite(os.path.join(pathPrefix,labelId+'_gt_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_gt_2[trImgCntr]+'.png'), gtTensor)
                        # cv2.imwrite(os.path.join(pathPrefix,labelId+'_pred_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_fake_img_c2_s1[trImgCntr]+'.png'), predTensor) 
                except Exception as e:
                    print('Warning while saving validation image:', e)
            
            fCntr += 1
    
    webpage.add_images(ims, txts, width=256, realFlag=opt.realVaData)    
    webpage.save()
    
    avg_valMSE = valMSE/float(evalCntr)
    avg_valSSIM = valSSIM/float(evalCntr)
    avg_valPSNR = valPSNR/float(evalCntr)
    avg_c1_s1_input_wer = c1_s1_input_correct/float(evalCntr)
    avg_c2_s1_gen_wer = c2_s1_gen_correct/float(evalCntr)
    avg_c1_s1_input_cer = c1_s1_input_ed_correct/float(evalCntr)
    avg_c2_s1_gen_cer = c2_s1_gen_ed_correct/float(evalCntr)

    # if not(opt.realVaData):
    with open(os.path.join(opt.exp_dir,opt.exp_name,'log_test.txt'), 'a') as log:
        # training loss and validation loss
        if opt.realVaData:
            loss_log = f'Test Input Word Acc: {avg_c1_s1_input_wer:0.5f}, Test Gen Word Acc: {avg_c2_s1_gen_wer:0.5f}, Test Input Char Acc: {avg_c1_s1_input_cer:0.5f}, Test Gen Char Acc: {avg_c2_s1_gen_cer:0.5f}'
        else:
            loss_log = f'Test MSE: {avg_valMSE:0.5f}, Test SSIM: {avg_valSSIM:0.5f}, Test PSNR: {avg_valPSNR:0.5f}, Test Input Word Acc: {avg_c1_s1_input_wer:0.5f}, Test Gen Word Acc: {avg_c2_s1_gen_wer:0.5f}, Test Input Char Acc: {avg_c1_s1_input_cer:0.5f}, Test Gen Char Acc: {avg_c2_s1_gen_cer:0.5f}'
        
        print(loss_log)
        log.write(loss_log+"\n")
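
A sketch of the per-image quality metrics accumulated in the loop above, assuming tensor2im has already converted both crops to HxWxC uint8 arrays; the skimage calls mirror the ones in the loop (floats in [0, 1], multichannel SSIM).

from skimage.metrics import (mean_squared_error, peak_signal_noise_ratio,
                             structural_similarity)

def image_quality_metrics(gt_uint8, pred_uint8):
    gt, pred = gt_uint8 / 255.0, pred_uint8 / 255.0   # scale to [0, 1] as in the loop above
    mse = mean_squared_error(gt, pred)
    ssim = structural_similarity(gt, pred, multichannel=True)
    psnr = peak_signal_noise_ratio(gt, pred)
    return mse, ssim, psnr
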
Example #15
def runDeepTextNet(segmentedImagesList):
    opt = argparse.Namespace(FeatureExtraction='ResNet',
                             PAD=False,
                             Prediction='Attn',
                             SequenceModeling='BiLSTM',
                             Transformation='TPS',
                             batch_max_length=25,
                             batch_size=192,
                             character='0123456789abcdefghijklmnopqrstuvwxyz',
                             hidden_size=256,
                             image_folder='demo_image/',
                             imgH=32,
                             imgW=100,
                             input_channel=1,
                             num_class=38,
                             num_fiducial=20,
                             num_gpu=0,
                             output_channel=512,
                             rgb=False,
                             saved_model='TPS-ResNet-BiLSTM-Attn.pth',
                             sensitive=False,
                             workers=4)

    model = Model(opt)
    model = torch.nn.DataParallel(model).to('cpu')
    directory = "TPS-ResNet-BiLSTM-Attn.pth"
    model.load_state_dict(torch.load(directory, map_location='cpu'))

    converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    if opt.rgb:
        opt.input_channel = 3

    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=segmentedImagesList, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    # predict
    model.eval()

    out_preds_texts = []
    for image_tensors, image_path_list in demo_loader:
        batch_size = image_tensors.size(0)
        image = image_tensors.to(device)
        # For max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                          batch_size).to(device)
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                         1).fill_(0).to(device)
        preds = model(image, text_for_pred, is_train=False)
        # select max probability (greedy decoding) then decode index to character
        _, preds_index = preds.max(2)
        preds_str = converter.decode(preds_index, length_for_pred)
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        for img_name, pred, pred_max_prob in zip(image_path_list, preds_str,
                                                 preds_max_prob):
            if 'Attn' in opt.Prediction:
                pred_EOS = pred.find('[s]')
                pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                pred_max_prob = pred_max_prob[:pred_EOS]

            # calculate confidence score (= product of pred_max_prob values)
            confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            # print(pred)
            out_preds_texts.append(pred)
    # print(out_preds_texts)

    sentence_out = [' '.join(out_preds_texts)]
    return (sentence_out)
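
A hedged usage example for runDeepTextNet; 'demo_image/' is the crop folder assumed by the default opt above, and the exact input format depends on what RawDataset accepts (a root folder of word-crop images here).

segmented_crops = 'demo_image/'            # hypothetical folder of detector word crops
sentence = runDeepTextNet(segmented_crops)
print(sentence)                            # e.g. ['word1 word2 ...']
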
def train(opt):
    plotDir = os.path.join(opt.exp_dir, opt.exp_name, 'plots')
    if not os.path.exists(plotDir):
        os.makedirs(plotDir)

    lib.print_model_settings(locals().copy())
    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')

    log = open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)

    train_dataset, train_dataset_log = hierarchical_dataset(
        root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    log.write(train_dataset_log)
    print('-' * 80)

    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=False,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)

    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    ocrModel = ModelV1(opt)
    styleModel = StyleTensorEncoder(input_dim=opt.input_channel)
    genModel = AdaIN_Tensor_WordGenerator(opt)
    disModel = MsImageDisV2(opt)

    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
        ocrCriterion = torch.nn.L1Loss()
    else:
        if 'CTC' in opt.Prediction:
            ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
        else:
            ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
                device)  # ignore [GO] token = ignore index 0

    vggRecCriterion = torch.nn.L1Loss()
    vggModel = VGGPerceptualLossModel(models.vgg19(pretrained=True),
                                      vggRecCriterion)

    print('model input parameters', opt.imgH, opt.imgW, opt.input_channel,
          opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length)

    #  weight initialization
    for currModel in [styleModel, genModel, disModel]:
        for name, param in currModel.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception as e:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue

    styleModel = torch.nn.DataParallel(styleModel).to(device)
    styleModel.train()

    genModel = torch.nn.DataParallel(genModel).to(device)
    genModel.train()

    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    vggModel = torch.nn.DataParallel(vggModel).to(device)
    vggModel.eval()

    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    ocrModel.module.Transformation.eval()
    ocrModel.module.FeatureExtraction.eval()
    ocrModel.module.AdaptiveAvgPool.eval()
    # ocrModel.module.SequenceModeling.eval()
    ocrModel.module.Prediction.eval()

    if opt.modelFolderFlag:
        if len(
                glob.glob(
                    os.path.join(opt.exp_dir, opt.exp_name,
                                 "iter_*_synth.pth"))) > 0:
            opt.saved_synth_model = glob.glob(
                os.path.join(opt.exp_dir, opt.exp_name,
                             "iter_*_synth.pth"))[-1]

    if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None':
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        checkpoint = torch.load(opt.saved_ocr_model)
        ocrModel.load_state_dict(checkpoint)

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)

        styleModel.load_state_dict(checkpoint['styleModel'])
        genModel.load_state_dict(checkpoint['genModel'])
        disModel.load_state_dict(checkpoint['disModel'])

    if opt.imgReconLoss == 'l1':
        recCriterion = torch.nn.L1Loss()
    elif opt.imgReconLoss == 'ssim':
        recCriterion = ssim
    elif opt.imgReconLoss == 'ms-ssim':
        recCriterion = msssim

    if opt.styleLoss == 'l1':
        styleRecCriterion = torch.nn.L1Loss()
    elif opt.styleLoss == 'triplet':
        styleRecCriterion = torch.nn.TripletMarginLoss(
            margin=opt.tripletMargin, p=1)
    #for validation; check only positive pairs
    styleTestRecCriterion = torch.nn.L1Loss()

    # loss averager
    loss_avg = Averager()
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    loss_avg_imgRecon = Averager()
    loss_avg_vgg_per = Averager()
    loss_avg_vgg_sty = Averager()
    loss_avg_ocr = Averager()

    ##---------------------------------------##
    # filter parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, styleModel.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    for p in filter(lambda p: p.requires_grad, genModel.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable style and generator params num : ', sum(params_num))

    # setup optimizer
    if opt.optim == 'adam':
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, opt.beta2),
                               weight_decay=opt.weight_decay)
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps,
                                   weight_decay=opt.weight_decay)
    print("SynthOptimizer:")
    print(optimizer)

    #filter parameters for Dis training
    dis_filtered_parameters = []
    dis_params_num = []
    for p in filter(lambda p: p.requires_grad, disModel.parameters()):
        dis_filtered_parameters.append(p)
        dis_params_num.append(np.prod(p.size()))
    print('Dis Trainable params num : ', sum(dis_params_num))

    # setup optimizer
    if opt.optim == 'adam':
        dis_optimizer = optim.Adam(dis_filtered_parameters,
                                   lr=opt.lr,
                                   betas=(opt.beta1, opt.beta2),
                                   weight_decay=opt.weight_decay)
    else:
        dis_optimizer = optim.Adadelta(dis_filtered_parameters,
                                       lr=opt.lr,
                                       rho=opt.rho,
                                       eps=opt.eps,
                                       weight_decay=opt.weight_decay)
    print("DisOptimizer:")
    print(dis_optimizer)
    ##---------------------------------------##
    """ final options """
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'),
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        try:
            start_iter = int(
                opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    #get schedulers
    scheduler = get_scheduler(optimizer, opt)
    dis_scheduler = get_scheduler(dis_optimizer, opt)

    start_time = time.time()
    iteration = start_iter
    cntr = 0

    while (True):
        # train part
        if opt.lr_policy != "None":
            scheduler.step()
            dis_scheduler.step()

        # Python 3 iterator protocol: use next(); note a fresh loader iterator is built every step
        image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(
            iter(train_loader))

        cntr += 1

        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        text_1, length_1 = converter.encode(
            labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(
            labels_2, batch_max_length=opt.batch_max_length)

        #forward pass from style and word generator
        style = styleModel(image_input_tensors)

        images_recon_2 = genModel(style, text_2)

        #Domain discriminator: Dis update
        disModel.zero_grad()
        disCost = opt.disWeight * (disModel.module.calc_dis_loss(
            torch.cat((images_recon_2.detach(), image_input_tensors), dim=1),
            torch.cat((image_gt_tensors, image_input_tensors), dim=1)))

        disCost.backward()
        dis_optimizer.step()
        loss_avg_dis.add(disCost)

        # #[Style Encoder] + [Word Generator] update
        #Adversarial loss
        disGenCost = disModel.module.calc_gen_loss(
            torch.cat((images_recon_2, image_input_tensors), dim=1))

        #Input reconstruction loss
        recCost = recCriterion(images_recon_2, image_gt_tensors)

        #vgg loss
        vggPerCost, vggStyleCost = vggModel(image_gt_tensors, images_recon_2)

        #ocr loss
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                         1).fill_(0).to(device)
        length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                          batch_size).to(device)
        if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
            preds_recon = ocrModel(images_recon_2,
                                   text_for_pred,
                                   is_train=False,
                                   returnFeat=opt.contentLoss)
            preds_gt = ocrModel(image_gt_tensors,
                                text_for_pred,
                                is_train=False,
                                returnFeat=opt.contentLoss)
            ocrCost = ocrCriterion(preds_recon, preds_gt)
        else:
            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(images_recon_2,
                                       text_for_pred,
                                       is_train=False)
                preds_o = preds_recon.detach().clone()[:, :text_1.shape[1] - 1, :]  # tensors have no deepcopy(); detached copy is only used for decoding
                preds_size = torch.IntTensor([preds_recon.size(1)] *
                                             batch_size)
                preds_recon = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon, text_2, preds_size,
                                       length_2)

                #predict ocr recognition on generated images
                preds_o_size = torch.IntTensor([preds_o.size(1)] * batch_size)
                _, preds_o_index = preds_o.max(2)
                labels_o_ocr = converter.decode(preds_o_index.data,
                                                preds_o_size.data)

                #predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors,
                                   text_for_pred,
                                   is_train=False)
                preds_s = preds_s[:, :text_1.shape[1] - 1, :]
                preds_s_size = torch.IntTensor([preds_s.size(1)] * batch_size)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index.data,
                                                preds_s_size.data)

                #predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_input_tensors,
                                    text_for_pred,
                                    is_train=False)
                preds_sc = preds_sc[:, :text_2.shape[1] - 1, :]
                preds_sc_size = torch.IntTensor([preds_sc.size(1)] *
                                                batch_size)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index.data,
                                                 preds_sc_size.data)

            else:
                preds_recon = ocrModel(
                    images_recon_2, text_for_pred[:, :-1],
                    is_train=False)  # align with Attention.forward
                target_2 = text_2[:, 1:]  # without [GO] Symbol
                ocrCost = ocrCriterion(
                    preds_recon.view(-1, preds_recon.shape[-1]),
                    target_2.contiguous().view(-1))

                #predict ocr recognition on generated images
                _, preds_o_index = preds_recon.max(2)
                labels_o_ocr = converter.decode(preds_o_index, length_for_pred)
                for idx, pred in enumerate(labels_o_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_o_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

                #predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors,
                                   text_for_pred,
                                   is_train=False)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index, length_for_pred)
                for idx, pred in enumerate(labels_s_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_s_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

                #predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_gt_tensors,
                                    text_for_pred,
                                    is_train=False)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index,
                                                 length_for_pred)
                for idx, pred in enumerate(labels_sc_ocr):
                    pred_EOS = pred.find('[s]')
                    labels_sc_ocr[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

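        # Total generator objective: weighted sum of image reconstruction, adversarial,
        # VGG perceptual, VGG style and OCR content terms (weights come from opt).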
        cost = opt.reconWeight * recCost + opt.disWeight * disGenCost + opt.vggPerWeight * vggPerCost + opt.vggStyWeight * vggStyleCost + opt.ocrWeight * ocrCost

        styleModel.zero_grad()
        genModel.zero_grad()
        disModel.zero_grad()
        vggModel.zero_grad()
        ocrModel.zero_grad()

        cost.backward()
        optimizer.step()
        loss_avg.add(cost)

        #Individual losses
        loss_avg_gen.add(opt.disWeight * disGenCost)
        loss_avg_imgRecon.add(opt.reconWeight * recCost)
        loss_avg_vgg_per.add(opt.vggPerWeight * vggPerCost)
        loss_avg_vgg_sty.add(opt.vggStyWeight * vggStyleCost)
        loss_avg_ocr.add(opt.ocrWeight * ocrCost)

        # validation part
        # To see training progress, we also conduct validation when 'iteration == 0'.
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:

            #Save training images
            os.makedirs(os.path.join(opt.exp_dir, opt.exp_name, 'trainImages',
                                     str(iteration)),
                        exist_ok=True)
            for trImgCntr in range(batch_size):
                try:
                    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
                        save_image(
                            tensor2im(image_input_tensors[trImgCntr].detach()),
                            os.path.join(
                                opt.exp_dir, opt.exp_name, 'trainImages',
                                str(iteration),
                                str(trImgCntr) + '_sInput_' +
                                labels_1[trImgCntr] + '.png'))
                        save_image(
                            tensor2im(image_gt_tensors[trImgCntr].detach()),
                            os.path.join(
                                opt.exp_dir, opt.exp_name, 'trainImages',
                                str(iteration),
                                str(trImgCntr) + '_csGT_' +
                                labels_2[trImgCntr] + '.png'))
                        save_image(
                            tensor2im(images_recon_2[trImgCntr].detach()),
                            os.path.join(
                                opt.exp_dir, opt.exp_name, 'trainImages',
                                str(iteration),
                                str(trImgCntr) + '_csRecon_' +
                                labels_2[trImgCntr] + '.png'))
                    else:
                        save_image(
                            tensor2im(image_input_tensors[trImgCntr].detach()),
                            os.path.join(
                                opt.exp_dir, opt.exp_name, 'trainImages',
                                str(iteration),
                                str(trImgCntr) + '_sInput_' +
                                labels_1[trImgCntr] + '_' +
                                labels_s_ocr[trImgCntr] + '.png'))
                        save_image(
                            tensor2im(image_gt_tensors[trImgCntr].detach()),
                            os.path.join(
                                opt.exp_dir, opt.exp_name, 'trainImages',
                                str(iteration),
                                str(trImgCntr) + '_csGT_' +
                                labels_2[trImgCntr] + '_' +
                                labels_sc_ocr[trImgCntr] + '.png'))
                        save_image(
                            tensor2im(images_recon_2[trImgCntr].detach()),
                            os.path.join(
                                opt.exp_dir, opt.exp_name, 'trainImages',
                                str(iteration),
                                str(trImgCntr) + '_csRecon_' +
                                labels_2[trImgCntr] + '_' +
                                labels_o_ocr[trImgCntr] + '.png'))
                except Exception as e:
                    print('Warning while saving training image:', e)

            elapsed_time = time.time() - start_time
            # for log

            with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'),
                      'a') as log:
                styleModel.eval()
                genModel.eval()
                disModel.eval()

                with torch.no_grad():
                    valid_loss, infer_time, length_of_data = validation_synth_v4(
                        iteration, styleModel, genModel, vggModel, ocrModel,
                        disModel, recCriterion, ocrCriterion, valid_loader,
                        converter, opt)

                styleModel.train()
                genModel.train()
                disModel.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train Synth loss: {loss_avg.val():0.5f}, \
                    Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\
                    Train ImgRecon loss: {loss_avg_imgRecon.val():0.5f}, Train VGG-Per loss: {loss_avg_vgg_per.val():0.5f},\
                    Train VGG-Sty loss: {loss_avg_vgg_sty.val():0.5f}, Train OCR loss: {loss_avg_ocr.val():0.5f}, Valid Synth loss: {valid_loss[0]:0.5f}, \
                    Valid Dis loss: {valid_loss[1]:0.5f}, Valid Gen loss: {valid_loss[2]:0.5f}, \
                    Valid ImgRecon loss: {valid_loss[3]:0.5f}, Valid VGG-Per loss: {valid_loss[4]:0.5f}, \
                    Valid VGG-Sty loss: {valid_loss[5]:0.5f}, Valid OCR loss: {valid_loss[6]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'

                #plotting
                lib.plot.plot(os.path.join(plotDir, 'Train-Synth-Loss'),
                              loss_avg.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-Dis-Loss'),
                              loss_avg_dis.val().item())

                lib.plot.plot(os.path.join(plotDir, 'Train-Gen-Loss'),
                              loss_avg_gen.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-ImgRecon1-Loss'),
                              loss_avg_imgRecon.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-VGG-Per-Loss'),
                              loss_avg_vgg_per.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-VGG-Sty-Loss'),
                              loss_avg_vgg_sty.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-OCR-Loss'),
                              loss_avg_ocr.val().item())

                lib.plot.plot(os.path.join(plotDir, 'Valid-Synth-Loss'),
                              valid_loss[0].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-Dis-Loss'),
                              valid_loss[1].item())

                lib.plot.plot(os.path.join(plotDir, 'Valid-Gen-Loss'),
                              valid_loss[2].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-ImgRecon1-Loss'),
                              valid_loss[3].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-VGG-Per-Loss'),
                              valid_loss[4].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-VGG-Sty-Loss'),
                              valid_loss[5].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-OCR-Loss'),
                              valid_loss[6].item())

                print(loss_log)

                loss_avg.reset()
                loss_avg_dis.reset()

                loss_avg_gen.reset()
                loss_avg_imgRecon.reset()
                loss_avg_vgg_per.reset()
                loss_avg_vgg_sty.reset()
                loss_avg_ocr.reset()

            lib.plot.flush()

        lib.plot.tick()

        # save model per 1e+5 iter.
        if (iteration) % 1e+4 == 0:
            torch.save(
                {
                    'styleModel': styleModel.state_dict(),
                    'genModel': genModel.state_dict(),
                    'disModel': disModel.state_dict()
                },
                os.path.join(opt.exp_dir, opt.exp_name,
                             'iter_' + str(iteration + 1) + '_synth.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
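
Checkpoints are written above as iter_<iteration+1>_synth.pth containing the three sub-module state dicts; a minimal resume sketch under that assumption (the models must already be built and wrapped in DataParallel exactly as in train()).

import glob
import os
import torch

def load_latest_synth_checkpoint(exp_dir, exp_name, styleModel, genModel, disModel, device='cpu'):
    ckpts = glob.glob(os.path.join(exp_dir, exp_name, 'iter_*_synth.pth'))
    if not ckpts:
        return 0  # nothing to resume from
    # sort numerically by the iteration number embedded in the file name
    ckpts.sort(key=lambda p: int(os.path.basename(p).split('_')[1]))
    checkpoint = torch.load(ckpts[-1], map_location=device)
    styleModel.load_state_dict(checkpoint['styleModel'])
    genModel.load_state_dict(checkpoint['genModel'])
    disModel.load_state_dict(checkpoint['disModel'])
    return int(os.path.basename(ckpts[-1]).split('_')[1])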