Example No. 1
def test(opt):
    lib.print_model_settings(locals().copy())
    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    log = open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)

    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=False,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True,
        drop_last=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)
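    # AttnLabelConverter adds the [GO] and [s] tokens used by attention decoding, while
    # CTCLabelConverter adds the CTC blank, so opt.num_class below depends on the chosen head.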

    opt.num_class = len(converter.character)

    g_ema = styleGANGen(opt.size,
                        opt.latent,
                        opt.n_mlp,
                        opt.num_class,
                        channel_multiplier=opt.channel_multiplier)
    g_ema = torch.nn.DataParallel(g_ema).to(device)
    g_ema.eval()
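    # g_ema is the exponential-moving-average generator; test() only samples from this copy,
    # with its weights restored from the 'g_ema' entry of the checkpoint loaded below.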

    print('model input parameters', opt.imgH, opt.imgW, opt.input_channel,
          opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length)

    ## Loading pre-trained files
    if opt.modelFolderFlag:
        saved_models = glob.glob(
            os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))
        if len(saved_models) > 0:
            opt.saved_synth_model = saved_models[-1]

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)

        g_ema.load_state_dict(checkpoint['g_ema'], strict=False)

    # pdb.set_trace()
    if opt.truncation < 1:
        with torch.no_grad():
            mean_latent = g_ema.module.mean_latent_content(opt.truncation_mean)

    else:
        mean_latent = None

    cntr = 0

    for i, (image_input_tensors, image_gt_tensors, labels_1,
            labels_2) in enumerate(valid_loader):
        print(i, len(valid_loader))
        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        text_1, length_1 = converter.encode(
            labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(
            labels_2, batch_max_length=opt.batch_max_length)

        #forward pass from style and word generator
        if opt.fixedStyleBatch:
            fixstyle = []
            # pdb.set_trace()
            style = mixing_noise(1, opt.latent, opt.mixing, device)
            fixstyle.append(style[0].repeat(opt.batch_size, 1))
            if len(style) > 1:
                fixstyle.append(style[1].repeat(opt.batch_size, 1))
            style = fixstyle
        else:
            style = mixing_noise(opt.batch_size, opt.latent, opt.mixing,
                                 device)

        if 'CTC' in opt.Prediction:
            images_recon_2, _ = g_ema(style,
                                      text_2,
                                      input_is_latent=opt.input_latent,
                                      inject_index=5,
                                      truncation=opt.truncation,
                                      truncation_latent=mean_latent,
                                      randomize_noise=False)
        else:
            images_recon_2, _ = g_ema(style,
                                      text_2[:, 1:-1],
                                      input_is_latent=opt.input_latent,
                                      inject_index=5,
                                      truncation=opt.truncation,
                                      truncation_latent=mean_latent,
                                      randomize_noise=False)

        # make sure the output directory exists before saving the test images below
        os.makedirs(opt.valDir, exist_ok=True)
        for trImgCntr in range(batch_size):
            try:
                save_image(
                    tensor2im(image_input_tensors[trImgCntr].detach()),
                    os.path.join(
                        opt.valDir,
                        str(cntr) + '_' + str(trImgCntr) + '_sInput_' +
                        labels_1[trImgCntr] + '.png'))
                save_image(
                    tensor2im(image_gt_tensors[trImgCntr].detach()),
                    os.path.join(
                        opt.valDir,
                        str(cntr) + '_' + str(trImgCntr) + '_csGT_' +
                        labels_2[trImgCntr] + '.png'))
                save_image(
                    tensor2im(images_recon_2[trImgCntr].detach()),
                    os.path.join(
                        opt.valDir,
                        str(cntr) + '_' + str(trImgCntr) + '_csRecon_' +
                        labels_2[trImgCntr] + '.png'))
            except Exception:
                print('Warning while saving validation image')
        cntr += 1
def train(opt):
    lib.print_model_settings(locals().copy())
    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    log = open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)

    train_dataset, train_dataset_log = hierarchical_dataset(
        root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True,
        drop_last=True)
    log.write(train_dataset_log)
    print('-' * 80)

    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=False,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True,
        drop_last=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)

    opt.num_class = len(converter.character)

    ocrModel = ModelV1(opt)
    # styleModel = StyleTensorEncoder(input_dim=opt.input_channel)
    # genModel = AdaIN_Tensor_WordGenerator(opt)
    # disModel = MsImageDisV2(opt)

    styleModel = StyleLatentEncoder(input_dim=opt.input_channel, norm='none')
    mixModel = Mixer(opt, nblk=3, dim=opt.latent)
    genModel = styleGANGen(opt.size,
                           opt.latent,
                           opt.n_mlp,
                           channel_multiplier=opt.channel_multiplier)
    disModel = styleGANDis(opt.size,
                           channel_multiplier=opt.channel_multiplier,
                           input_dim=opt.input_channel * 2)
    g_ema = styleGANGen(opt.size,
                        opt.latent,
                        opt.n_mlp,
                        channel_multiplier=opt.channel_multiplier)
    accumulate(g_ema, genModel, 0)
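    # Assuming the standard StyleGAN2 accumulate(target, source, decay) helper,
    # decay 0 simply copies the freshly built generator weights into g_ema.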

    #  weight initialization
    for currModel in [styleModel, mixModel]:
        for name, param in currModel.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception as e:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue

    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
        ocrCriterion = torch.nn.L1Loss()
    else:
        if 'CTC' in opt.Prediction:
            ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
        else:
            ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
                device)  # ignore [GO] token = ignore index 0

    vggRecCriterion = torch.nn.L1Loss()
    vggModel = VGGPerceptualLossModel(models.vgg19(pretrained=True),
                                      vggRecCriterion)

    print('model input parameters', opt.imgH, opt.imgW, opt.input_channel,
          opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length)

    styleModel = torch.nn.DataParallel(styleModel).to(device)
    styleModel.train()

    mixModel = torch.nn.DataParallel(mixModel).to(device)
    mixModel.train()

    genModel = torch.nn.DataParallel(genModel).to(device)
    g_ema = torch.nn.DataParallel(g_ema).to(device)
    genModel.eval()
    g_ema.eval()

    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.eval()

    vggModel = torch.nn.DataParallel(vggModel).to(device)
    vggModel.eval()

    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    ocrModel.module.Transformation.eval()
    ocrModel.module.FeatureExtraction.eval()
    ocrModel.module.AdaptiveAvgPool.eval()
    # ocrModel.module.SequenceModeling.eval()
    ocrModel.module.Prediction.eval()

    g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1)
    d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1)
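    # Lazy-regularization trick from StyleGAN2: the regularizers run only every *_reg_every
    # steps, so the learning rate and Adam betas are rescaled by reg_every / (reg_every + 1).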

    optimizer = optim.Adam(
        list(styleModel.parameters()) + list(mixModel.parameters()),
        lr=opt.lr * g_reg_ratio,
        betas=(0**g_reg_ratio, 0.99**g_reg_ratio),
    )
    dis_optimizer = optim.Adam(
        disModel.parameters(),
        lr=opt.lr * d_reg_ratio,
        betas=(0**d_reg_ratio, 0.99**d_reg_ratio),
    )

    ## Loading pre-trained files
    if opt.modelFolderFlag:
        saved_models = glob.glob(
            os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))
        if len(saved_models) > 0:
            opt.saved_synth_model = saved_models[-1]

    if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None':
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        checkpoint = torch.load(opt.saved_ocr_model)
        ocrModel.load_state_dict(checkpoint)

    if opt.saved_gen_model != '' and opt.saved_gen_model != 'None':
        print(f'loading pretrained gen model from {opt.saved_gen_model}')
        checkpoint = torch.load(opt.saved_gen_model,
                                map_location=lambda storage, loc: storage)
        genModel.module.load_state_dict(checkpoint['g'])
        g_ema.module.load_state_dict(checkpoint['g_ema'])

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)

        styleModel.load_state_dict(checkpoint['styleModel'])
        mixModel.load_state_dict(checkpoint['mixModel'])
        genModel.load_state_dict(checkpoint['genModel'])
        g_ema.load_state_dict(checkpoint['g_ema'])
        disModel.load_state_dict(checkpoint['disModel'])

        optimizer.load_state_dict(checkpoint["optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])

    if opt.imgReconLoss == 'l1':
        recCriterion = torch.nn.L1Loss()
    elif opt.imgReconLoss == 'ssim':
        recCriterion = ssim
    elif opt.imgReconLoss == 'ms-ssim':
        recCriterion = msssim

    # loss averager
    loss_avg = Averager()
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    loss_avg_imgRecon = Averager()
    loss_avg_vgg_per = Averager()
    loss_avg_vgg_sty = Averager()
    loss_avg_ocr = Averager()

    log_r1_val = Averager()
    log_avg_path_loss_val = Averager()
    log_avg_mean_path_length_avg = Averager()
    log_ada_aug_p = Averager()
    """ final options """
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'),
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        try:
            start_iter = int(
                opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except Exception:
            pass

    #get schedulers
    scheduler = get_scheduler(optimizer, opt)
    # dis_scheduler = get_scheduler(dis_optimizer,opt)

    start_time = time.time()
    iteration = start_iter
    cntr = 0

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    accum = 0.5**(32 / (10 * 1000))
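    # accum is the per-step EMA decay used in the reference StyleGAN2 code (half-life of roughly
    # 10k images at batch size 32); note the accumulate(g_ema, genModel, accum) call further down
    # is commented out, so g_ema only changes when a checkpoint is loaded.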
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0
    ada_aug_step = opt.ada_target / opt.ada_length
    r_t_stat = 0

    # sample_z = torch.randn(opt.n_sample, opt.latent, device=device)

    while (True):
        # print(cntr)
        # train part
        if opt.lr_policy != "None":
            scheduler.step()
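        # The discriminator update is stubbed out here: disCost and the real/fake/r1 entries are
        # fixed zeros, so only the generator-side losses computed below are actually optimized.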

        disCost = torch.tensor(0.0)
        loss_dict["d"] = disCost * opt.disWeight
        loss_dict["real_score"] = torch.tensor(0.0)
        loss_dict["fake_score"] = torch.tensor(0.0)

        loss_avg_dis.add(disCost)
        loss_dict["r1"] = torch.tensor(0.0)

        # #[Style Encoder] + [Word Generator] update
        image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(
            iter(train_loader))

        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        text_1, length_1 = converter.encode(
            labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(
            labels_2, batch_max_length=opt.batch_max_length)

        style = styleModel(image_input_tensors).squeeze(2).squeeze(2)
        scInput = mixModel(style, text_2)
        images_recon_2, _ = genModel([scInput],
                                     input_is_latent=opt.input_latent)

        disGenCost = torch.tensor(0.0)

        loss_dict["g"] = disGenCost

        #Input reconstruction loss
        recCost = recCriterion(images_recon_2, image_gt_tensors)

        #vgg loss
        vggPerCost, vggStyleCost = vggModel(image_gt_tensors, images_recon_2)

        #ocr loss
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                         1).fill_(0).to(device)
        length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                          batch_size).to(device)
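        # text_for_pred is a dummy sequence filled with the [GO] index (0) that seeds the
        # attention decoder (the CTC head ignores it); length_for_pred caps decoding at
        # batch_max_length.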
        if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
            preds_recon = ocrModel(images_recon_2,
                                   text_for_pred,
                                   is_train=False,
                                   returnFeat=opt.contentLoss)
            preds_gt = ocrModel(image_gt_tensors,
                                text_for_pred,
                                is_train=False,
                                returnFeat=opt.contentLoss)
            ocrCost = ocrCriterion(preds_recon, preds_gt)
        else:
            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(images_recon_2,
                                       text_for_pred,
                                       is_train=False)
                preds_o = preds_recon.clone()[:, :text_1.shape[1] - 1, :]
                preds_size = torch.IntTensor([preds_recon.size(1)] *
                                             batch_size)
                preds_recon = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon, text_2, preds_size,
                                       length_2)

                #predict ocr recognition on generated images
                preds_o_size = torch.IntTensor([preds_o.size(1)] * batch_size)
                _, preds_o_index = preds_o.max(2)
                labels_o_ocr = converter.decode(preds_o_index.data,
                                                preds_o_size.data)

                #predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors,
                                   text_for_pred,
                                   is_train=False)
                preds_s = preds_s[:, :text_1.shape[1] - 1, :]
                preds_s_size = torch.IntTensor([preds_s.size(1)] * batch_size)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index.data,
                                                preds_s_size.data)

                #predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_gt_tensors,
                                    text_for_pred,
                                    is_train=False)
                preds_sc = preds_sc[:, :text_2.shape[1] - 1, :]
                preds_sc_size = torch.IntTensor([preds_sc.size(1)] *
                                                batch_size)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index.data,
                                                 preds_sc_size.data)

            else:
                preds_recon = ocrModel(
                    images_recon_2, text_for_pred[:, :-1],
                    is_train=False)  # align with Attention.forward
                target_2 = text_2[:, 1:]  # without [GO] Symbol
                ocrCost = ocrCriterion(
                    preds_recon.view(-1, preds_recon.shape[-1]),
                    target_2.contiguous().view(-1))

                #predict ocr recognition on generated images
                _, preds_o_index = preds_recon.max(2)
                labels_o_ocr = converter.decode(preds_o_index, length_for_pred)
                for idx, pred in enumerate(labels_o_ocr):
                    pred_EOS = pred.find('[s]')
                    # prune after "end of sentence" token ([s])
                    labels_o_ocr[idx] = pred[:pred_EOS]

                #predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors,
                                   text_for_pred,
                                   is_train=False)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index, length_for_pred)
                for idx, pred in enumerate(labels_s_ocr):
                    pred_EOS = pred.find('[s]')
                    # prune after "end of sentence" token ([s])
                    labels_s_ocr[idx] = pred[:pred_EOS]

                #predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_gt_tensors,
                                    text_for_pred,
                                    is_train=False)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index,
                                                 length_for_pred)
                for idx, pred in enumerate(labels_sc_ocr):
                    pred_EOS = pred.find('[s]')
                    # prune after "end of sentence" token ([s])
                    labels_sc_ocr[idx] = pred[:pred_EOS]

        cost = (opt.reconWeight * recCost + opt.disWeight * disGenCost +
                opt.vggPerWeight * vggPerCost +
                opt.vggStyWeight * vggStyleCost + opt.ocrWeight * ocrCost)

        styleModel.zero_grad()
        genModel.zero_grad()
        mixModel.zero_grad()
        disModel.zero_grad()
        vggModel.zero_grad()
        ocrModel.zero_grad()

        cost.backward()
        optimizer.step()
        loss_avg.add(cost)

        loss_dict["path"] = torch.tensor(0.0)
        loss_dict["path_length"] = torch.tensor(0.0)

        # accumulate(g_ema, genModel, accum)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        #Individual losses
        loss_avg_gen.add(opt.disWeight * disGenCost)
        loss_avg_imgRecon.add(opt.reconWeight * recCost)
        loss_avg_vgg_per.add(opt.vggPerWeight * vggPerCost)
        loss_avg_vgg_sty.add(opt.vggStyWeight * vggStyleCost)
        loss_avg_ocr.add(opt.ocrWeight * ocrCost)

        log_r1_val.add(loss_reduced["r1"])
        log_avg_path_loss_val.add(loss_reduced["path"])
        log_avg_mean_path_length_avg.add(torch.tensor(0.0))
        log_ada_aug_p.add(torch.tensor(ada_aug_p))

        if get_rank() == 0:
            # pbar.set_description(
            #     (
            #         f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
            #         f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
            #         f"augment: {ada_aug_p:.4f}"
            #     )
            # )

            if wandb and opt.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

        # validation part
        # To see training progress, we also conduct validation when 'iteration == 0'.
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:

            #Save training images
            # images_recon_2, _ = g_ema([scInput[:,:opt.latent],scInput[:,opt.latent:]])
            # images_recon_2, _ = g_ema(scInput)

            os.makedirs(os.path.join(opt.trainDir, str(iteration)),
                        exist_ok=True)
            for trImgCntr in range(batch_size):
                try:
                    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
                        save_image(
                            tensor2im(image_input_tensors[trImgCntr].detach()),
                            os.path.join(
                                opt.trainDir, str(iteration),
                                str(trImgCntr) + '_sInput_' +
                                labels_1[trImgCntr] + '.png'))
                        save_image(
                            tensor2im(image_gt_tensors[trImgCntr].detach()),
                            os.path.join(
                                opt.trainDir, str(iteration),
                                str(trImgCntr) + '_csGT_' +
                                labels_2[trImgCntr] + '.png'))
                        save_image(
                            tensor2im(images_recon_2[trImgCntr].detach()),
                            os.path.join(
                                opt.trainDir, str(iteration),
                                str(trImgCntr) + '_csRecon_' +
                                labels_2[trImgCntr] + '.png'))
                    else:
                        save_image(
                            tensor2im(image_input_tensors[trImgCntr].detach()),
                            os.path.join(
                                opt.trainDir, str(iteration),
                                str(trImgCntr) + '_sInput_' +
                                labels_1[trImgCntr] + '_' +
                                labels_s_ocr[trImgCntr] + '.png'))
                        save_image(
                            tensor2im(image_gt_tensors[trImgCntr].detach()),
                            os.path.join(
                                opt.trainDir, str(iteration),
                                str(trImgCntr) + '_csGT_' +
                                labels_2[trImgCntr] + '_' +
                                labels_sc_ocr[trImgCntr] + '.png'))
                        save_image(
                            tensor2im(images_recon_2[trImgCntr].detach()),
                            os.path.join(
                                opt.trainDir, str(iteration),
                                str(trImgCntr) + '_csRecon_' +
                                labels_2[trImgCntr] + '_' +
                                labels_o_ocr[trImgCntr] + '.png'))
                except Exception:
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time
            # for log

            with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'),
                      'a') as log:
                styleModel.eval()
                genModel.eval()
                g_ema.eval()
                mixModel.eval()
                disModel.eval()

                with torch.no_grad():
                    valid_loss, infer_time, length_of_data = validation_synth_v5(
                        iteration, styleModel, genModel, mixModel, vggModel,
                        ocrModel, disModel, recCriterion, ocrCriterion,
                        valid_loader, converter, opt)

                styleModel.train()
                genModel.eval()
                mixModel.train()
                disModel.eval()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train Synth loss: {loss_avg.val():0.5f}, \
                    Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\
                    Train ImgRecon loss: {loss_avg_imgRecon.val():0.5f}, Train VGG-Per loss: {loss_avg_vgg_per.val():0.5f},\
                    Train VGG-Sty loss: {loss_avg_vgg_sty.val():0.5f}, Train OCR loss: {loss_avg_ocr.val():0.5f}, \
                    Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \
                    Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \
                    Valid Synth loss: {valid_loss[0]:0.5f}, \
                    Valid Dis loss: {valid_loss[1]:0.5f}, Valid Gen loss: {valid_loss[2]:0.5f}, \
                    Valid ImgRecon loss: {valid_loss[3]:0.5f}, Valid VGG-Per loss: {valid_loss[4]:0.5f}, \
                    Valid VGG-Sty loss: {valid_loss[5]:0.5f}, Valid OCR loss: {valid_loss[6]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'

                #plotting
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Synth-Loss'),
                              loss_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Dis-Loss'),
                              loss_avg_dis.val().item())

                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Gen-Loss'),
                              loss_avg_gen.val().item())
                lib.plot.plot(
                    os.path.join(opt.plotDir, 'Train-ImgRecon1-Loss'),
                    loss_avg_imgRecon.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-VGG-Per-Loss'),
                              loss_avg_vgg_per.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-VGG-Sty-Loss'),
                              loss_avg_vgg_sty.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-OCR-Loss'),
                              loss_avg_ocr.val().item())

                lib.plot.plot(os.path.join(opt.plotDir, 'Train-r1_val'),
                              log_r1_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-path_loss_val'),
                              log_avg_path_loss_val.val().item())
                lib.plot.plot(
                    os.path.join(opt.plotDir, 'Train-mean_path_length_avg'),
                    log_avg_mean_path_length_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-ada_aug_p'),
                              log_ada_aug_p.val().item())

                lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Synth-Loss'),
                              valid_loss[0].item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Dis-Loss'),
                              valid_loss[1].item())

                lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Gen-Loss'),
                              valid_loss[2].item())
                lib.plot.plot(
                    os.path.join(opt.plotDir, 'Valid-ImgRecon1-Loss'),
                    valid_loss[3].item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Valid-VGG-Per-Loss'),
                              valid_loss[4].item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Valid-VGG-Sty-Loss'),
                              valid_loss[5].item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Valid-OCR-Loss'),
                              valid_loss[6].item())

                print(loss_log)

                loss_avg.reset()
                loss_avg_dis.reset()

                loss_avg_gen.reset()
                loss_avg_imgRecon.reset()
                loss_avg_vgg_per.reset()
                loss_avg_vgg_sty.reset()
                loss_avg_ocr.reset()

                log_r1_val.reset()
                log_avg_path_loss_val.reset()
                log_avg_mean_path_length_avg.reset()
                log_ada_aug_p.reset()

            lib.plot.flush()

        lib.plot.tick()

        # save model every 1e+4 iterations.
        if iteration % 10000 == 0:
            torch.save(
                {
                    'styleModel': styleModel.state_dict(),
                    'mixModel': mixModel.state_dict(),
                    'genModel': genModel.state_dict(),
                    'g_ema': g_ema.state_dict(),
                    'disModel': disModel.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'dis_optimizer': dis_optimizer.state_dict()
                },
                os.path.join(opt.exp_dir, opt.exp_name,
                             'iter_' + str(iteration + 1) + '_synth.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
        cntr += 1
Example No. 3
def test(opt):
    lib.print_model_settings(locals().copy())

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
        text_len = opt.batch_max_length + 2
    else:
        converter = CTCLabelConverter(opt.character)
        text_len = opt.batch_max_length

    opt.classes = converter.character
    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)

    valid_dataset = LmdbDataset(root=opt.test_data, opt=opt)
    test_data_sampler = data_sampler(valid_dataset,
                                     shuffle=False,
                                     distributed=False)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=False,  # 'True' to check training progress with validation function.
        sampler=test_data_sampler,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True,
        drop_last=False)

    print('-' * 80)

    opt.num_class = len(converter.character)

    ocrModel = ModelV1(opt).to(device)

    ## Loading pre-trained files
    print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
    checkpoint = torch.load(opt.saved_ocr_model,
                            map_location=lambda storage, loc: storage)
    ocrModel.load_state_dict(checkpoint)

    evalCntr = 0
    fCntr = 0

    c1_s1_input_correct = 0.0
    c1_s1_input_ed_correct = 0.0
    # pdb.set_trace()

    for vCntr, (image_input_tensors, labels_gt) in enumerate(valid_loader):
        print(vCntr)

        image_input_tensors = image_input_tensors.to(device)
        text_gt, length_gt = converter.encode(
            labels_gt, batch_max_length=opt.batch_max_length)

        with torch.no_grad():
            currBatchSize = image_input_tensors.shape[0]
            # text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device)
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              currBatchSize).to(device)
            #Run OCR prediction
            if 'CTC' in opt.Prediction:
                preds = ocrModel(image_input_tensors, text_gt, is_train=False)
                preds_size = torch.IntTensor([preds.size(1)] *
                                             image_input_tensors.shape[0])
                _, preds_index = preds.max(2)
                preds_str_gt_1 = converter.decode(preds_index.data,
                                                  preds_size.data)

            else:
                preds = ocrModel(
                    image_input_tensors, text_gt[:, :-1],
                    is_train=False)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_gt_1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_gt_1):
                    pred_EOS = pred.find('[s]')
                    # prune after "end of sentence" token ([s])
                    preds_str_gt_1[idx] = pred[:pred_EOS]

        for trImgCntr in range(image_input_tensors.shape[0]):
            #ocr accuracy
            # for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            c1_s1_input_gt = labels_gt[trImgCntr]
            c1_s1_input_ocr = preds_str_gt_1[trImgCntr]

            if c1_s1_input_gt == c1_s1_input_ocr:
                c1_s1_input_correct += 1

            # ICDAR2019 Normalized Edit Distance
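            # Per-sample score: 1 - ED(pred, gt) / max(len(gt), len(pred)), or 0 if either string is empty.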
            if len(c1_s1_input_gt) == 0 or len(c1_s1_input_ocr) == 0:
                c1_s1_input_ed_correct += 0
            elif len(c1_s1_input_gt) > len(c1_s1_input_ocr):
                c1_s1_input_ed_correct += 1 - edit_distance(
                    c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_gt)
            else:
                c1_s1_input_ed_correct += 1 - edit_distance(
                    c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_ocr)

            evalCntr += 1

            fCntr += 1

    avg_c1_s1_input_wer = c1_s1_input_correct / float(evalCntr)
    avg_c1_s1_input_cer = c1_s1_input_ed_correct / float(evalCntr)
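    # Despite the *_wer / *_cer names, these are word accuracy and the mean 1 - normalized edit
    # distance, matching the 'Word Acc' / 'Char Acc' labels written to the log below.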

    # if not(opt.realVaData):
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_test.txt'),
              'a') as log:
        # training loss and validation loss

        loss_log = f'Word Acc: {avg_c1_s1_input_wer:0.5f}, Test Input Char Acc: {avg_c1_s1_input_cer:0.5f}'

        print(loss_log)
        log.write(loss_log + "\n")
Example No. 4
class trainer(object):
    def __init__(self, opt):

        opt.src_select_data = opt.src_select_data.split('-')
        opt.src_batch_ratio = opt.src_batch_ratio.split('-')
        opt.tar_select_data = opt.tar_select_data.split('-')
        opt.tar_batch_ratio = opt.tar_batch_ratio.split('-')
        """ vocab / character number configuration """
        if opt.sensitive:
            # opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
            opt.character = string.printable[:-6]  # same with ASTER setting (use 94 char).

        if opt.char_dict is not None:
            opt.character = load_char_dict(
                opt.char_dict)[3:-2]  # strip the special symbols introduced by Attention and CTC
        """ model configuration """
        self.converter = AttnLabelConverter(opt.character)
        opt.num_class = len(self.converter.character)

        if opt.rgb:
            opt.input_channel = 3
        self.opt = opt
        print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
              opt.input_channel, opt.output_channel, opt.hidden_size,
              opt.num_class, opt.batch_max_length, opt.Transformation,
              opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
        self.save_opt_log(opt)

        self.build_model(opt)

    def dataloader(self, opt):
        src_train_data = opt.src_train_data
        src_select_data = opt.src_select_data
        src_batch_ratio = opt.src_batch_ratio
        src_train_dataset = Batch_Balanced_Dataset(opt, src_train_data,
                                                   src_select_data,
                                                   src_batch_ratio)

        tar_train_data = opt.tar_train_data
        tar_select_data = opt.tar_select_data
        tar_batch_ratio = opt.tar_batch_ratio
        tar_train_dataset = Batch_Balanced_Dataset(opt, tar_train_data,
                                                   tar_select_data,
                                                   tar_batch_ratio)

        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)

        valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        return src_train_dataset, tar_train_dataset, valid_loader

    def _optimizer(self, opt):
        # keep only the parameters that require gradient descent
        filtered_parameters = []
        params_num = []
        for p in filter(lambda p: p.requires_grad, self.model.parameters()):
            filtered_parameters.append(p)
            params_num.append(np.prod(p.size()))
        print('Trainable params num : ', sum(params_num))
        # setup optimizer
        if opt.adam:
            self.optimizer = optim.Adam(filtered_parameters,
                                        lr=opt.lr,
                                        betas=(opt.beta1, 0.999))
        else:
            self.optimizer = optim.Adadelta(filtered_parameters,
                                            lr=opt.lr,
                                            rho=opt.rho,
                                            eps=opt.eps)

        print("Optimizer:")
        print(self.optimizer)

    def build_model(self, opt):
        print('-' * 80)
        """ Define Model """
        self.model = Model(opt)

        self.weight_initializer()

        self.model = torch.nn.DataParallel(self.model).to(device)
        """ Define Loss """
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)
        self.D_criterion = torch.nn.CrossEntropyLoss().to(device)
        """ Trainer """
        self._optimizer(opt)

    def train(self, opt):
        # Add a custom dataset and cfgs following da-faster-rcnn.
        # Make sure you change the imdb_name in factory.py.
        """
        Dummy format:

        args.src_dataset == '$YOUR_DATASET_NAME'
        args.src_imdb_name = '$YOUR_DATASET_NAME_2007_trainval'
        args.src_imdbval_name = '$YOUR_DATASET_NAME_2007_test'
        args.set_cfgs = [...]
        """

        # src, tar dataloaders
        src_dataset, tar_dataset, valid_loader = self.dataloader(opt)
        src_dataset_size = src_dataset.total_data_size
        tar_dataset_size = tar_dataset.total_data_size
        train_size = max([src_dataset_size, tar_dataset_size])

        self.model.train()

        start_iter = 0

        if opt.continue_model != '':
            self.load(opt.continue_model)
            print(" [*] Load SUCCESS")
            # if opt.decay_flag and start_iter > (opt.num_iter // 2):
            #     self.d_image_opt.param_groups[0]['lr'] -= (opt.lr / (opt.num_iter // 2)) * (
            #             start_iter - opt.num_iter // 2)
            #     self.d_inst_opt.param_groups[0]['lr'] -= (opt.lr / (opt.num_iter // 2)) * (
            #             start_iter - opt.num_iter // 2)

        # loss averager
        cls_loss_avg = Averager()
        sim_loss_avg = Averager()
        loss_avg = Averager()

        # training loop
        print('training start !')
        start_time = time.time()
        best_accuracy = -1
        best_norm_ED = 1e+6

        for step in range(start_iter, opt.num_iter + 1):

            src_image, src_labels = src_dataset.get_batch()
            src_image = src_image.to(device)
            src_text, src_length = self.converter.encode(
                src_labels, batch_max_length=opt.batch_max_length)

            tar_image, tar_labels = tar_dataset.get_batch()
            tar_image = tar_image.to(device)
            tar_text, tar_length = self.converter.encode(
                tar_labels, batch_max_length=opt.batch_max_length)

            # Set gradient to zero...
            self.model.zero_grad()

            # Attention # align with Attention.forward
            src_preds, src_global_feature, src_local_feature = self.model(
                src_image, src_text[:, :-1])
            target = src_text[:, 1:]  # without [GO] Symbol
            src_cls_loss = self.criterion(
                src_preds.view(-1, src_preds.shape[-1]),
                target.contiguous().view(-1))

            src_local_feature = src_local_feature.view(
                -1, src_local_feature.shape[-1])
            # TODO
            tar_preds, tar_global_feature, tar_local_feature = self.model(
                tar_image, tar_text[:, :-1], is_train=False)

            tar_local_feature = tar_local_feature.view(
                -1, tar_local_feature.shape[-1])

            d_inst_loss = coral_loss(src_local_feature, src_preds,
                                     tar_local_feature, tar_preds)
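            # coral_loss is assumed to align second-order feature statistics (CORAL) between the
            # source and target local features / predictions, acting as the domain-adaptation term.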
            # Add domain loss
            loss = src_cls_loss.mean() + 0.1 * d_inst_loss.mean()
            loss_avg.add(loss)
            cls_loss_avg.add(src_cls_loss)
            sim_loss_avg.add(d_inst_loss)

            # frcnn backward
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                opt.grad_clip)  # gradient clipping with 5 (Default)
            # frcnn optimizer update
            self.optimizer.step()

            # validation part
            if step % opt.valInterval == 0:

                elapsed_time = time.time() - start_time
                print(
                    f'[{step}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} CLS_Loss: {cls_loss_avg.val():0.5f} SIMI_Loss: {sim_loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
                )
                # for log
                with open(
                        f'./saved_models/{opt.experiment_name}/log_train.txt',
                        'a') as log:
                    log.write(
                        f'[{step}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                    )
                    loss_avg.reset()
                    cls_loss_avg.reset()
                    sim_loss_avg.reset()

                    self.model.eval()
                    with torch.no_grad():
                        valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                            self.model, self.criterion, valid_loader,
                            self.converter, opt)

                    self.print_prediction_result(preds, labels, log)

                    valid_log = f'[{step}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                    valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                    print(valid_log)
                    log.write(valid_log + '\n')

                    self.model.train()

                    # keep best accuracy model
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        save_name = f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                        self.save(opt, save_name)
                    if current_norm_ED < best_norm_ED:
                        best_norm_ED = current_norm_ED
                        save_name = f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                        self.save(opt, save_name)

                    best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                    print(best_model_log)
                    log.write(best_model_log + '\n')

            # save model per 1e+5 iter.
            if (step + 1) % 1e+5 == 0:
                save_name = f'./saved_models/{opt.experiment_name}/iter_{step+1}.pth'
                self.save(opt, save_name)

    def load(self, saved_model):
        params = torch.load(saved_model)

        if 'model' not in params:
            self.model.load_state_dict(params)
        else:
            self.model.load_state_dict(params['model'])

    def save(self, opt, save_name):
        params = {}
        params['model'] = self.model.state_dict()
        # for training
        params['optimizer'] = self.optimizer.state_dict()
        torch.save(params, save_name)
        print('Successfully save model: {}'.format(save_name))

    def weight_initializer(self):
        # weight initialization
        for name, param in self.model.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception as e:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue

    def save_opt_log(self, opt):
        """ final options """
        # print(opt)
        with open(f'./saved_models/{opt.experiment_name}/opt.txt',
                  'a') as opt_file:
            opt_log = '------------ Options -------------\n'
            args = vars(opt)
            for k, v in args.items():
                opt_log += f'{str(k)}: {str(v)}\n'
            opt_log += '---------------------------------------\n'
            print(opt_log)
            opt_file.write(opt_log)

    def print_prediction_result(self, preds, labels, fp_log):
        """
         fp-logwenjian
        :param preds:
        :param labels:
        :param fp_log:
        :return:
        """
        for pred, gt in zip(preds[:5], labels[:5]):
            if 'Attn' in self.opt.Prediction:
                pred = pred[:pred.find('[s]')]
                gt = gt[:gt.find('[s]')]
            print(f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}')
            fp_log.write(f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}\n')
Example No. 5
def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()
    """ model configuration """
    # CTCLoss
    converter_ctc = CTCLabelConverter(opt.character)
    # Attention
    converter_atten = AttnLabelConverter(opt.character)
    opt.num_class_ctc = len(converter_ctc.character)
    opt.num_class_atten = len(converter_atten.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class_ctc, opt.num_class_atten, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling,
          opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # keep only the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p_: p_.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # use fp16 to train
    model = model.to(device)
    if opt.fp16:
        with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
            log.write('==> Enable fp16 training' + '\n')
        print('==> Enable fp16 training')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # data parallel for multi-GPU
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).to(device)
    model.train()
    # for i in model.module.Prediction_atten:
    #     i.to(device)
    # for i in model.module.Feat_Extraction.scr:
    #     i.to(device)
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)
    """ setup loss """
    criterion_ctc = torch.nn.CTCLoss(zero_infinity=True).to(device)
    criterion_atten = torch.nn.CrossEntropyLoss(ignore_index=0).to(
        device)  # ignore [GO] token = ignore index 0

    # loss averager
    loss_avg = Averager()
    """ final options """
    writer = SummaryWriter(f'./saved_models/{opt.exp_name}')
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except Exception:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    # image_tensors, labels = train_dataset.get_batch()
    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        batch_size = image.size(0)
        text_ctc, length_ctc = converter_ctc.encode(
            labels, batch_max_length=opt.batch_max_length)
        text_atten, length_atten = converter_atten.encode(
            labels, batch_max_length=opt.batch_max_length)

        # model output is a tuple: (CTC prediction tensor, list of attention predictions);
        # text_atten[:, :-1] aligns with Attention.forward
        preds_ctc, preds_atten = model(image, text_atten[:, :-1])
        # CTC Loss
        preds_size = torch.IntTensor([preds_ctc.size(1)] * batch_size)
        # _, preds_index = preds_ctc.max(2)
        # preds_str_ctc = converter_ctc.decode(preds_index.data, preds_size.data)
        preds_ctc = preds_ctc.log_softmax(2).permute(1, 0, 2)
        cost_ctc = 0.1 * criterion_ctc(preds_ctc, text_ctc, preds_size,
                                       length_ctc)

        # Attention Loss
        # preds_atten = [i[:, :text_atten.shape[1] - 1, :] for i in preds_atten]
        # # select max probabilty (greedy decoding) then decode index to character
        # preds_index_atten = [i.max(2)[1] for i in preds_atten]
        # length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
        # preds_str_atten = [converter_atten.decode(i, length_for_pred) for i in preds_index_atten]
        # preds_str_atten2 = preds_str_atten
        # preds_str_atten = []
        # for i in preds_str_atten2:  # prune after "end of sentence" token ([s])
        #     temp = []
        #     for j in i:
        #         j = j[:j.find('[s]')]
        #         temp.append(j)
        #     preds_str_atten.append(temp)
        # preds_str_atten = [j[:j.find('[s]')] for i in preds_str_atten for j in i]
        target = text_atten[:, 1:]  # without [GO] Symbol
        # cost_atten = 1.0 * criterion_atten(preds_atten.view(-1, preds_atten.shape[-1]), target.contiguous().view(-1))
        for index, pred in enumerate(preds_atten):
            if index == 0:
                cost_atten = 1.0 * criterion_atten(
                    pred.view(-1, pred.shape[-1]),
                    target.contiguous().view(-1))
            else:
                cost_atten += 1.0 * criterion_atten(
                    pred.view(-1, pred.shape[-1]),
                    target.contiguous().view(-1))
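        # The attention branch returns a list of predictions (one per decoder head), so the
        # cross-entropy cost is summed over all of them against the same target sequence,
        # which has the leading [GO] token stripped (text_atten[:, 1:]).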
        # cost_atten = [1.0 * criterion_atten(pred.view(-1, pred.shape[-1]), target.contiguous().view(-1)) for pred in
        #               preds_atten]
        # cost_atten = criterion_atten(preds_atten.view(-1, preds_atten.shape[-1]), target.contiguous().view(-1))
        cost = cost_ctc + cost_atten
        writer.add_scalar('loss', cost.item(), global_step=iteration + 1)

        # cost = cost_ctc
        # cost = cost_atten
        if (iteration + 1) % 100 == 0:
            print('\riter: {:4d}\tloss: {:6.3f}\tavg: {:6.3f}'.format(
                iteration + 1, cost.item(), loss_avg.val()),
                  end='\n')
        else:
            print('\riter: {:4d}\tloss: {:6.3f}\tavg: {:6.3f}'.format(
                iteration + 1, cost.item(), loss_avg.val()),
                  end='')
        sys.stdout.flush()
        if cost < 0.001:
            print(f'iter: {iteration + 1}\tloss: {cost}')
            # aaaaaa = 0

        # model.zero_grad()
        optimizer.zero_grad()
        if torch.isnan(cost):
            print(f'iter: {iteration + 1}\tloss: {cost}\t==> Loss is NAN')
            sys.exit()
        elif torch.isinf(cost):
            print(f'iter: {iteration + 1}\tloss: {cost}\t==> Loss is INF')
            sys.exit()
        else:
            if opt.fp16:
                with amp.scale_loss(cost, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                cost.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)
        writer.add_scalar('loss_avg',
                          loss_avg.val(),
                          global_step=iteration + 1)
        # if loss_avg.val() <= 0.6:
        #     opt.grad_clip = 2
        # if loss_avg.val() <= 0.3:
        #     opt.grad_clip = 1

        # validation part
        if iteration == 0 or (
                iteration + 1
        ) % opt.valInterval == 0:  # To see training progress, we also conduct validation when 'iteration == 0'
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion_atten, valid_loader, converter_atten,
                        opt)
                model.train()
                writer.add_scalar('accuracy',
                                  current_accuracy,
                                  global_step=iteration + 1)

                # training loss and validation loss
                loss_log = f'[{iteration + 1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    gt = gt[:gt.find('[s]')]
                    pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(
                model.state_dict(),
                f'./saved_models/{opt.exp_name}/iter_{iteration + 1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()

        # if (iteration + 1) % opt.valInterval == 0:
        #     print(f'iter: {iteration + 1}\tloss: {cost}')
        iteration += 1
Example #6
def test(opt):
    lib.print_model_settings(locals().copy())

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
        text_len = opt.batch_max_length+2
    else:
        converter = CTCLabelConverter(opt.character)
        text_len = opt.batch_max_length
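    # The +2 for attention decoding leaves room for the [GO] and [s] tokens that
    # AttnLabelConverter adds around each label; CTC targets keep the raw max length.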

    opt.classes = converter.character
    
    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    # AlignCollate_valid = AlignPairImgCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    AlignCollate_valid = AlignPairImgCollate_Test(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    
    valid_dataset = LmdbTestStyleContentDataset(root=opt.test_data, opt=opt, dataMode=opt.realVaData)
    test_data_sampler = data_sampler(valid_dataset, shuffle=False, distributed=False)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size, 
        shuffle=False,  # 'True' to check training progress with validation function.
        sampler=test_data_sampler,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=False)
    
    print('-' * 80)

    AlignCollate_text = AlignSynthTextCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    text_dataset = text_gen_synth(opt)
    text_data_sampler = data_sampler(text_dataset, shuffle=True, distributed=False)
    text_loader = torch.utils.data.DataLoader(
        text_dataset, batch_size=opt.batch_size,
        shuffle=False,
        sampler=text_data_sampler,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_text,
        drop_last=True)
    opt.num_class = len(converter.character)

 
    text_loader = sample_data(text_loader, text_data_sampler, False)

    c_code_size = opt.latent
    if opt.cEncode == 'mlp':
        cEncoder = GlobalContentEncoder(opt.num_class, text_len, opt.char_embed_size, c_code_size).to(device)
    elif opt.cEncode == 'cnn':
        if opt.contentNorm == 'in':
            cEncoder = ResNet_StyleExtractor_WIN(1, opt.latent).to(device)
        else:
            cEncoder = ResNet_StyleExtractor(1, opt.latent).to(device)
    if opt.styleNorm == 'in':
        styleModel = ResNet_StyleExtractor_WIN(opt.input_channel, opt.latent).to(device)
    else:
        styleModel = ResNet_StyleExtractor(opt.input_channel, opt.latent).to(device)
    
    ocrModel = ModelV1(opt).to(device)

    g_ema = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, content_dim=c_code_size, channel_multiplier=opt.channel_multiplier).to(device)
    g_ema.eval()
    
    bestModelError=1e5

    ## Loading pre-trained files
    print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
    checkpoint = torch.load(opt.saved_ocr_model, map_location=lambda storage, loc: storage)
    ocrModel.load_state_dict(checkpoint)

    print(f'loading pretrained synth model from {opt.saved_synth_model}')
    checkpoint = torch.load(opt.saved_synth_model, map_location=lambda storage, loc: storage)
    
    cEncoder.load_state_dict(checkpoint['cEncoder'])
    styleModel.load_state_dict(checkpoint['styleModel'])
    g_ema.load_state_dict(checkpoint['g_ema'])
     
    
    iCntr=0
    evalCntr=0
    fCntr=0
    
    valMSE=0.0
    valSSIM=0.0
    valPSNR=0.0
    c1_s1_input_correct=0.0
    c2_s1_gen_correct=0.0
    c1_s1_input_ed_correct=0.0
    c2_s1_gen_ed_correct=0.0
    
    
    ims, txts = [], []
    
    for vCntr, (image_input_tensors, image_output_tensors, labels_gt, labels_z_c, labelSynthImg, synth_z_c, input_1_shape, input_2_shape) in enumerate(valid_loader):
        print(vCntr)

        if opt.debugFlag and vCntr >10:
            break  
        
        image_input_tensors = image_input_tensors.to(device)
        image_output_tensors = image_output_tensors.to(device)

        if opt.realVaData and opt.outPairFile=="":
            # pdb.set_trace()
            labels_z_c, synth_z_c = next(text_loader)
        
        labelSynthImg = labelSynthImg.to(device)
        synth_z_c = synth_z_c.to(device)
        synth_z_c = synth_z_c[:labelSynthImg.shape[0]]
        
        text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length)
        text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length)
        
        # print(labels_z_c)

        cEncoder.eval()
        styleModel.eval()
        g_ema.eval()

        with torch.no_grad():
            if opt.cEncode == 'mlp':    
                z_c_code = cEncoder(text_z_c)
                z_gt_code = cEncoder(text_gt)
            elif opt.cEncode == 'cnn':    
                z_c_code = cEncoder(synth_z_c)
                z_gt_code = cEncoder(labelSynthImg)
                
            style = styleModel(image_input_tensors)
            
            if opt.noiseConcat or opt.zAlone:
                style = mixing_noise(opt.batch_size, opt.latent, opt.mixing, device, style)
            else:
                style = [style]
            
            fake_img_c1_s1, _ = g_ema(style, z_gt_code, input_is_latent=opt.input_latent)
            fake_img_c2_s1, _ = g_ema(style, z_c_code, input_is_latent=opt.input_latent)
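            # Two passes through the EMA generator with the same extracted style:
            # fake_img_c1_s1 uses the content code of the input's own text (a
            # reconstruction-style output), while fake_img_c2_s1 swaps in the new
            # content code z_c_code, transferring only the style.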

            currBatchSize = fake_img_c1_s1.shape[0]
            # text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device)
            length_for_pred = torch.IntTensor([opt.batch_max_length] * currBatchSize).to(device)
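            # OCR is run on both generated images and both real images below; CTC models
            # are decoded greedily over the full sequence, while attention models are
            # decoded up to the first '[s]' end-of-sequence token.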
            #Run OCR prediction
            if 'CTC' in opt.Prediction:
                preds = ocrModel(fake_img_c1_s1, text_gt, is_train=False, inAct = opt.taskActivation)
                preds_size = torch.IntTensor([preds.size(1)] * currBatchSize)
                _, preds_index = preds.max(2)
                preds_str_fake_img_c1_s1 = converter.decode(preds_index.data, preds_size.data)

                preds = ocrModel(fake_img_c2_s1, text_z_c, is_train=False, inAct = opt.taskActivation)
                preds_size = torch.IntTensor([preds.size(1)] * currBatchSize)
                _, preds_index = preds.max(2)
                preds_str_fake_img_c2_s1 = converter.decode(preds_index.data, preds_size.data)

                preds = ocrModel(image_input_tensors, text_gt, is_train=False)
                preds_size = torch.IntTensor([preds.size(1)] * image_input_tensors.shape[0])
                _, preds_index = preds.max(2)
                preds_str_gt_1 = converter.decode(preds_index.data, preds_size.data)

                preds = ocrModel(image_output_tensors, text_z_c, is_train=False)
                preds_size = torch.IntTensor([preds.size(1)] * image_output_tensors.shape[0])
                _, preds_index = preds.max(2)
                preds_str_gt_2 = converter.decode(preds_index.data, preds_size.data)

            else:
                
                preds = ocrModel(fake_img_c1_s1, text_gt[:, :-1], is_train=False, inAct = opt.taskActivation)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_fake_img_c1_s1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_fake_img_c1_s1):
                    pred_EOS = pred.find('[s]')
                    preds_str_fake_img_c1_s1[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                
                preds = ocrModel(fake_img_c2_s1, text_z_c[:, :-1], is_train=False, inAct = opt.taskActivation)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_fake_img_c2_s1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_fake_img_c2_s1):
                    pred_EOS = pred.find('[s]')
                    preds_str_fake_img_c2_s1[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                

                preds = ocrModel(image_input_tensors, text_gt[:, :-1], is_train=False)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_gt_1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_gt_1):
                    pred_EOS = pred.find('[s]')
                    preds_str_gt_1[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

                preds = ocrModel(image_output_tensors, text_z_c[:, :-1], is_train=False)  # align with Attention.forward
                _, preds_index = preds.max(2)
                preds_str_gt_2 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(preds_str_gt_2):
                    pred_EOS = pred.find('[s]')
                    preds_str_gt_2[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])

        pathPrefix = os.path.join(opt.valDir, opt.exp_iter)
        os.makedirs(os.path.join(pathPrefix), exist_ok=True)
        
        for trImgCntr in range(image_output_tensors.shape[0]):
            
            if opt.outPairFile!="":
                labelId = 'label-' + valid_loader.dataset.pairId[fCntr] + '-' + str(fCntr)
            else:
                labelId = 'label-%09d' % valid_loader.dataset.filtered_index_list[fCntr]
            #evaluations
            valRange = (-1,+1)
            # pdb.set_trace()
            # inpTensor = skimage.img_as_ubyte(resize(tensor2im(image_input_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0])))
            # gtTensor = skimage.img_as_ubyte(resize(tensor2im(image_output_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0])))
            # predTensor = skimage.img_as_ubyte(resize(tensor2im(fake_img_c2_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0])))
            
            # inpTensor = resize(tensor2im(image_input_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]), anti_aliasing=True)
            # gtTensor = resize(tensor2im(image_output_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]), anti_aliasing=True)
            # predTensor = resize(tensor2im(fake_img_c2_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]), anti_aliasing=True)

            inpTensor = F.interpolate(image_input_tensors[trImgCntr].unsqueeze(0),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]))
            gtTensor = F.interpolate(image_output_tensors[trImgCntr].unsqueeze(0),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]))
            predTensor = F.interpolate(fake_img_c2_s1[trImgCntr].unsqueeze(0),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]))
            predGTTensor = F.interpolate(fake_img_c1_s1[trImgCntr].unsqueeze(0),(input_1_shape[trImgCntr][1], input_1_shape[trImgCntr][0]))
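            # Input, ground-truth and generated tensors are resized back to each sample's
            # original shape (input_1_shape) with F.interpolate before the image metrics
            # are computed and the visualisations are saved.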

            # inpTensor = cv2.resize(tensor2im(image_input_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr]))
            # gtTensor = cv2.resize(tensor2im(image_output_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr]))
            # predTensor = cv2.resize(tensor2im(fake_img_c2_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr]))
            # predGTTensor = cv2.resize(tensor2im(fake_img_c1_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr]))

            # inpTensor = cv2.medianBlur(cv2.resize(tensor2im(image_input_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr])),5)
            # gtTensor = cv2.medianBlur(cv2.resize(tensor2im(image_output_tensors[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr])),5)
            # predTensor = cv2.medianBlur(cv2.resize(tensor2im(fake_img_c2_s1[trImgCntr].clone().clamp_(min=valRange[0], max=valRange[1])),tuple(input_1_shape[trImgCntr])),5)
            # pdb.set_trace()

            if not(opt.realVaData):
                evalMSE = mean_squared_error(tensor2im(gtTensor.squeeze())/255, tensor2im(predTensor.squeeze())/255)
                # evalMSE = mean_squared_error(gtTensor/255, predTensor/255)
                evalSSIM = structural_similarity(tensor2im(gtTensor.squeeze())/255, tensor2im(predTensor.squeeze())/255, multichannel=True)
                # evalSSIM = structural_similarity(gtTensor/255, predTensor/255, multichannel=True)
                evalPSNR = peak_signal_noise_ratio(tensor2im(gtTensor.squeeze())/255, tensor2im(predTensor.squeeze())/255)
                # evalPSNR = peak_signal_noise_ratio(gtTensor/255, predTensor/255)
                # print(evalMSE,evalSSIM,evalPSNR)

                valMSE+=evalMSE
                valSSIM+=evalSSIM
                valPSNR+=evalPSNR

            #ocr accuracy
            # for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            c1_s1_input_gt = labels_gt[trImgCntr]
            c1_s1_input_ocr = preds_str_gt_1[trImgCntr]
            c2_s1_gen_gt = labels_z_c[trImgCntr]
            c2_s1_gen_ocr = preds_str_fake_img_c2_s1[trImgCntr]

            if c1_s1_input_gt == c1_s1_input_ocr:
                c1_s1_input_correct += 1
            if c2_s1_gen_gt == c2_s1_gen_ocr:
                c2_s1_gen_correct += 1

            # ICDAR2019 Normalized Edit Distance
            if len(c1_s1_input_gt) == 0 or len(c1_s1_input_ocr) == 0:
                c1_s1_input_ed_correct += 0
            elif len(c1_s1_input_gt) > len(c1_s1_input_ocr):
                c1_s1_input_ed_correct += 1 - edit_distance(c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_gt)
            else:
                c1_s1_input_ed_correct += 1 - edit_distance(c1_s1_input_ocr, c1_s1_input_gt) / len(c1_s1_input_ocr)
            
            if len(c2_s1_gen_gt) == 0 or len(c2_s1_gen_ocr) == 0:
                c2_s1_gen_ed_correct += 0
            elif len(c2_s1_gen_gt) > len(c2_s1_gen_ocr):
                c2_s1_gen_ed_correct += 1 - edit_distance(c2_s1_gen_ocr, c2_s1_gen_gt) / len(c2_s1_gen_gt)
            else:
                c2_s1_gen_ed_correct += 1 - edit_distance(c2_s1_gen_ocr, c2_s1_gen_gt) / len(c2_s1_gen_ocr)
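            # ICDAR2019-style normalised edit distance: each sample contributes
            # 1 - ED(pred, gt) / max(len(gt), len(pred)), so an exact match adds 1.0;
            # the running sums are averaged over evalCntr after the loop.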
            
            evalCntr+=1
            
            #save generated images
            if opt.visFlag and iCntr>500:
                pass
            else:
                try:
                    if iCntr == 0:
                        # update website
                        webpage = html.HTML(pathPrefix, 'Experiment name = %s' % 'Test')
                        webpage.add_header('Testing iteration')
                        
                    iCntr += 1
                    img_path_c1_s1 = os.path.join(labelId+'_pred_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_fake_img_c1_s1[trImgCntr]+'.png')
                    img_path_gt_1 = os.path.join(labelId+'_gt_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_gt_1[trImgCntr]+'.png')
                    img_path_gt_2 = os.path.join(labelId+'_gt_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_gt_2[trImgCntr]+'.png')
                    img_path_c2_s1 = os.path.join(labelId+'_pred_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_fake_img_c2_s1[trImgCntr]+'.png')
                    
                    ims.append([img_path_gt_1, img_path_c1_s1, img_path_gt_2, img_path_c2_s1])

                    content_c1_s1 = 'PSTYLE-1 '+'Text-1:' + labels_gt[trImgCntr]+' OCR:' + preds_str_fake_img_c1_s1[trImgCntr]
                    content_gt_1 = 'OSTYLE-1 '+'GT:' + labels_gt[trImgCntr]+' OCR:' + preds_str_gt_1[trImgCntr]
                    content_gt_2 = 'OSTYLE-1 '+'GT:' + labels_z_c[trImgCntr]+' OCR:'+preds_str_gt_2[trImgCntr]
                    content_c2_s1 = 'PSTYLE-1 '+'Text-2:' + labels_z_c[trImgCntr]+' OCR:'+preds_str_fake_img_c2_s1[trImgCntr]
                    
                    txts.append([content_gt_1, content_c1_s1, content_gt_2, content_c2_s1])
                    
                    utils.save_image(predGTTensor,os.path.join(pathPrefix,labelId+'_pred_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_fake_img_c1_s1[trImgCntr]+'.png'),nrow=1,normalize=True,range=(-1, 1))
                    # cv2.imwrite(os.path.join(pathPrefix,labelId+'_pred_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_fake_img_c1_s1[trImgCntr]+'.png'), predGTTensor)
                    
                    # pdb.set_trace()
                    if not opt.zAlone:
                        utils.save_image(inpTensor,os.path.join(pathPrefix,labelId+'_gt_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_gt_1[trImgCntr]+'.png'),nrow=1,normalize=True,range=(-1, 1))
                        utils.save_image(gtTensor,os.path.join(pathPrefix,labelId+'_gt_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_gt_2[trImgCntr]+'.png'),nrow=1,normalize=True,range=(-1, 1))
                        utils.save_image(predTensor,os.path.join(pathPrefix,labelId+'_pred_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_fake_img_c2_s1[trImgCntr]+'.png'),nrow=1,normalize=True,range=(-1, 1)) 
                        
                        # imsave(os.path.join(pathPrefix,labelId+'_gt_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_gt_1[trImgCntr]+'.png'), inpTensor)
                        # imsave(os.path.join(pathPrefix,labelId+'_gt_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_gt_2[trImgCntr]+'.png'), gtTensor)
                        # imsave(os.path.join(pathPrefix,labelId+'_pred_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_fake_img_c2_s1[trImgCntr]+'.png'), predTensor) 

                        # cv2.imwrite(os.path.join(pathPrefix,labelId+'_gt_val_c1_s1_'+labels_gt[trImgCntr]+'_ocr_'+preds_str_gt_1[trImgCntr]+'.png'), inpTensor)
                        # cv2.imwrite(os.path.join(pathPrefix,labelId+'_gt_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_gt_2[trImgCntr]+'.png'), gtTensor)
                        # cv2.imwrite(os.path.join(pathPrefix,labelId+'_pred_val_c2_s1_'+labels_z_c[trImgCntr]+'_ocr_'+preds_str_fake_img_c2_s1[trImgCntr]+'.png'), predTensor) 
                except:
                    print('Warning while saving validation image')
            
            fCntr += 1
    
    webpage.add_images(ims, txts, width=256, realFlag=opt.realVaData)    
    webpage.save()
    
    avg_valMSE = valMSE/float(evalCntr)
    avg_valSSIM = valSSIM/float(evalCntr)
    avg_valPSNR = valPSNR/float(evalCntr)
    avg_c1_s1_input_wer = c1_s1_input_correct/float(evalCntr)
    avg_c2_s1_gen_wer = c2_s1_gen_correct/float(evalCntr)
    avg_c1_s1_input_cer = c1_s1_input_ed_correct/float(evalCntr)
    avg_c2_s1_gen_cer = c2_s1_gen_ed_correct/float(evalCntr)

    # if not(opt.realVaData):
    with open(os.path.join(opt.exp_dir,opt.exp_name,'log_test.txt'), 'a') as log:
        # training loss and validation loss
        if opt.realVaData:
            loss_log = f'Test Input Word Acc: {avg_c1_s1_input_wer:0.5f}, Test Gen Word Acc: {avg_c2_s1_gen_wer:0.5f}, Test Input Char Acc: {avg_c1_s1_input_cer:0.5f}, Test Gen Char Acc: {avg_c2_s1_gen_cer:0.5f}'
        else:
            loss_log = f'Test MSE: {avg_valMSE:0.5f}, Test SSIM: {avg_valSSIM:0.5f}, Test PSNR: {avg_valPSNR:0.5f}, Test Input Word Acc: {avg_c1_s1_input_wer:0.5f}, Test Gen Word Acc: {avg_c2_s1_gen_wer:0.5f}, Test Input Char Acc: {avg_c1_s1_input_cer:0.5f}, Test Gen Char Acc: {avg_c2_s1_gen_cer:0.5f}'
        
        print(loss_log)
        log.write(loss_log+"\n")
Example #7
def train(opt):
    lib.print_model_settings(locals().copy())

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
        text_len = opt.batch_max_length + 2
    else:
        converter = CTCLabelConverter(opt.character)
        text_len = opt.batch_max_length

    opt.classes = converter.character
    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    log = open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)

    train_dataset = LmdbStyleDataset(root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size *
        2,  # *2 so separate image batches can be drawn for the encoder and for the discriminator's real samples
        shuffle=
        True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True,
        drop_last=True)

    print('-' * 80)

    valid_dataset = LmdbStyleDataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size *
        2,  # *2 so separate image batches can be drawn for the encoder and for the discriminator's real samples
        shuffle=
        False,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True,
        drop_last=True)

    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    text_dataset = text_gen(opt)
    text_loader = torch.utils.data.DataLoader(text_dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=True,
                                              num_workers=int(opt.workers),
                                              pin_memory=True,
                                              drop_last=True)
    opt.num_class = len(converter.character)

    c_code_size = opt.latent
    cEncoder = GlobalContentEncoder(opt.num_class, text_len,
                                    opt.char_embed_size, c_code_size)
    ocrModel = ModelV1(opt)

    if opt.zAlone:
        genModel = styleGANGen(opt.size,
                               opt.latent,
                               opt.latent,
                               opt.n_mlp,
                               channel_multiplier=opt.channel_multiplier)
        g_ema = styleGANGen(opt.size,
                            opt.latent,
                            opt.latent,
                            opt.n_mlp,
                            channel_multiplier=opt.channel_multiplier)
    else:
        genModel = styleGANGen(opt.size,
                               opt.latent + c_code_size,
                               opt.latent,
                               opt.n_mlp,
                               channel_multiplier=opt.channel_multiplier)
        g_ema = styleGANGen(opt.size,
                            opt.latent + c_code_size,
                            opt.latent,
                            opt.n_mlp,
                            channel_multiplier=opt.channel_multiplier)
    disEncModel = styleGANDis(opt.size,
                              channel_multiplier=opt.channel_multiplier,
                              input_dim=opt.input_channel,
                              code_s_dim=c_code_size)

    accumulate(g_ema, genModel, 0)

    # uCriterion = torch.nn.MSELoss()
    # sCriterion = torch.nn.MSELoss()
    # if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
    #     ocrCriterion = torch.nn.L1Loss()
    # else:
    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        print('Not implemented error')
        sys.exit()
        # ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    cEncoder = torch.nn.DataParallel(cEncoder).to(device)
    cEncoder.train()
    genModel = torch.nn.DataParallel(genModel).to(device)
    g_ema = torch.nn.DataParallel(g_ema).to(device)
    genModel.train()
    g_ema.eval()

    disEncModel = torch.nn.DataParallel(disEncModel).to(device)
    disEncModel.train()

    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    if opt.ocrFixed:
        if opt.Transformation == 'TPS':
            ocrModel.module.Transformation.eval()
        ocrModel.module.FeatureExtraction.eval()
        ocrModel.module.AdaptiveAvgPool.eval()
        # ocrModel.module.SequenceModeling.eval()
        ocrModel.module.Prediction.eval()
    else:
        ocrModel.train()

    g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1)
    d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1)

    optimizer = optim.Adam(
        list(genModel.parameters()) + list(cEncoder.parameters()),
        lr=opt.lr * g_reg_ratio,
        betas=(0**g_reg_ratio, 0.99**g_reg_ratio),
    )
    dis_optimizer = optim.Adam(
        disEncModel.parameters(),
        lr=opt.lr * d_reg_ratio,
        betas=(0**d_reg_ratio, 0.99**d_reg_ratio),
    )
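    # Lazy-regularisation correction (StyleGAN2): since the R1 and path penalties are only
    # applied every *_reg_every steps, the learning rate and Adam betas of the generator
    # and discriminator optimisers are rescaled by reg_every / (reg_every + 1).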

    ocr_optimizer = optim.Adam(
        ocrModel.parameters(),
        lr=opt.lr,
        betas=(0.9, 0.99),
    )

    ## Loading pre-trained files
    if opt.modelFolderFlag:
        if len(
                glob.glob(
                    os.path.join(opt.exp_dir, opt.exp_name,
                                 "iter_*_synth.pth"))) > 0:
            opt.saved_synth_model = glob.glob(
                os.path.join(opt.exp_dir, opt.exp_name,
                             "iter_*_synth.pth"))[-1]

    if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None':
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        checkpoint = torch.load(opt.saved_ocr_model)
        ocrModel.load_state_dict(checkpoint)

    # if opt.saved_gen_model !='' and opt.saved_gen_model !='None':
    #     print(f'loading pretrained gen model from {opt.saved_gen_model}')
    #     checkpoint = torch.load(opt.saved_gen_model, map_location=lambda storage, loc: storage)
    #     genModel.module.load_state_dict(checkpoint['g'])
    #     g_ema.module.load_state_dict(checkpoint['g_ema'])

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)

        # styleModel.load_state_dict(checkpoint['styleModel'])
        # mixModel.load_state_dict(checkpoint['mixModel'])
        genModel.load_state_dict(checkpoint['genModel'])
        g_ema.load_state_dict(checkpoint['g_ema'])
        disEncModel.load_state_dict(checkpoint['disEncModel'])
        ocrModel.load_state_dict(checkpoint['ocrModel'])

        optimizer.load_state_dict(checkpoint["optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])
        ocr_optimizer.load_state_dict(checkpoint["ocr_optimizer"])

    # if opt.imgReconLoss == 'l1':
    #     recCriterion = torch.nn.L1Loss()
    # elif opt.imgReconLoss == 'ssim':
    #     recCriterion = ssim
    # elif opt.imgReconLoss == 'ms-ssim':
    #     recCriterion = msssim

    # loss averager
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    loss_avg_unsup = Averager()
    loss_avg_sup = Averager()
    log_r1_val = Averager()
    log_avg_path_loss_val = Averager()
    log_avg_mean_path_length_avg = Averager()
    log_ada_aug_p = Averager()
    loss_avg_ocr_sup = Averager()
    loss_avg_ocr_unsup = Averager()
    """ final options """
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'),
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        try:
            start_iter = int(
                opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    #get schedulers
    scheduler = get_scheduler(optimizer, opt)
    dis_scheduler = get_scheduler(dis_optimizer, opt)
    ocr_scheduler = get_scheduler(ocr_optimizer, opt)

    start_time = time.time()
    iteration = start_iter
    cntr = 0

    mean_path_length = 0
    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    # loss_dict = {}

    accum = 0.5**(32 / (10 * 1000))
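    # accum is the decay for the exponential moving average kept in g_ema;
    # 0.5 ** (32 / (10 * 1000)) is the usual StyleGAN2 setting (roughly a 10k-image
    # half-life assuming batches of 32), applied via accumulate(g_ema, genModel, accum).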
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0
    ada_aug_step = opt.ada_target / opt.ada_length
    r_t_stat = 0
    epsilon = 10e-50

    # sample_z = torch.randn(opt.n_sample, opt.latent, device=device)

    while (True):
        # print(cntr)
        # train part
        if opt.lr_policy != "None":
            scheduler.step()
            dis_scheduler.step()
            ocr_scheduler.step()

        image_input_tensors, _, labels, _ = next(iter(train_loader))
        labels_z_c = next(iter(text_loader))

        image_input_tensors = image_input_tensors.to(device)
        gt_image_tensors = image_input_tensors[:opt.batch_size].detach()
        real_image_tensors = image_input_tensors[opt.batch_size:].detach()

        labels_gt = labels[:opt.batch_size]

        requires_grad(cEncoder, False)
        requires_grad(genModel, False)
        requires_grad(disEncModel, True)
        requires_grad(ocrModel, False)

        text_z_c, length_z_c = converter.encode(
            labels_z_c, batch_max_length=opt.batch_max_length)
        text_gt, length_gt = converter.encode(
            labels_gt, batch_max_length=opt.batch_max_length)

        z_c_code = cEncoder(text_z_c)
        style = mixing_noise(z_c_code, opt.batch_size, opt.latent, opt.mixing,
                             device)

        if opt.zAlone:
            #to validate orig style gan results
            newstyle = []
            newstyle.append(style[0][:, :opt.latent])
            if len(style) > 1:
                newstyle.append(style[1][:, :opt.latent])
            style = newstyle

        fake_img, _ = genModel(style, input_is_latent=opt.input_latent)

        # #unsupervised code prediction on generated image
        # u_pred_code = disEncModel(fake_img, mode='enc')
        # uCost = uCriterion(u_pred_code, z_code)

        # #supervised code prediction on gt image
        # s_pred_code = disEncModel(gt_image_tensors, mode='enc')
        # sCost = uCriterion(s_pred_code, gt_phoc_tensors)

        #Domain discriminator
        fake_pred = disEncModel(fake_img)
        real_pred = disEncModel(real_image_tensors)
        disCost = d_logistic_loss(real_pred, fake_pred)
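        # Discriminator update: d_logistic_loss is the StyleGAN2 logistic GAN loss
        # (softplus(-real_pred) + softplus(fake_pred) in the reference implementation),
        # computed on real style images versus images generated from the mixed
        # (content code, noise) styles.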

        # dis_cost = disCost + opt.gamma_e*uCost + opt.beta*sCost

        loss_avg_dis.add(disCost)
        # loss_avg_sup.add(opt.beta*sCost)
        # loss_avg_unsup.add(opt.gamma_e * uCost)

        disEncModel.zero_grad()
        disCost.backward()
        dis_optimizer.step()

        d_regularize = cntr % opt.d_reg_every == 0

        if d_regularize:
            real_image_tensors.requires_grad = True
            real_pred = disEncModel(real_image_tensors)

            r1_loss = d_r1_loss(real_pred, real_image_tensors)

            disEncModel.zero_grad()
            (opt.r1 / 2 * r1_loss * opt.d_reg_every +
             0 * real_pred[0]).backward()

            dis_optimizer.step()
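        # Lazy R1 regularisation: every d_reg_every steps the gradient penalty d_r1_loss
        # (squared gradient of real_pred w.r.t. the real images, as in StyleGAN2) is applied,
        # scaled by opt.r1 / 2 * d_reg_every; the "+ 0 * real_pred[0]" term only keeps the
        # prediction in the autograd graph without changing the value.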
        log_r1_val.add(r1_loss)

        # Recognizer update
        if not opt.ocrFixed and not opt.zAlone:
            requires_grad(disEncModel, False)
            requires_grad(ocrModel, True)

            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(gt_image_tensors,
                                       text_gt,
                                       is_train=True)
                preds_size = torch.IntTensor([preds_recon.size(1)] *
                                             opt.batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(
                    1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_gt,
                                       preds_size, length_gt)
            else:
                print("Not implemented error")
                sys.exit()

            ocrModel.zero_grad()
            ocrCost.backward()
            # torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
            ocr_optimizer.step()
            loss_avg_ocr_sup.add(ocrCost)
        else:
            loss_avg_ocr_sup.add(torch.tensor(0.0))

        # [Word Generator] update
        # image_input_tensors, _, labels, _ = iter(train_loader).next()
        labels_z_c = next(iter(text_loader))

        # image_input_tensors = image_input_tensors.to(device)
        # gt_image_tensors = image_input_tensors[:opt.batch_size]
        # real_image_tensors = image_input_tensors[opt.batch_size:]

        # labels_gt = labels[:opt.batch_size]

        requires_grad(cEncoder, True)
        requires_grad(genModel, True)
        requires_grad(disEncModel, False)
        requires_grad(ocrModel, False)

        text_z_c, length_z_c = converter.encode(
            labels_z_c, batch_max_length=opt.batch_max_length)

        z_c_code = cEncoder(text_z_c)
        style = mixing_noise(z_c_code, opt.batch_size, opt.latent, opt.mixing,
                             device)

        if opt.zAlone:
            #to validate orig style gan results
            newstyle = []
            newstyle.append(style[0][:, :opt.latent])
            if len(style) > 1:
                newstyle.append(style[1][:, :opt.latent])
            style = newstyle

        fake_img, _ = genModel(style, input_is_latent=opt.input_latent)

        fake_pred = disEncModel(fake_img)
        disGenCost = g_nonsaturating_loss(fake_pred)

        if opt.zAlone:
            ocrCost = torch.tensor(0.0)
        else:
            #Compute OCR prediction (Reconstruction of content)
            # text_for_pred = torch.LongTensor(opt.batch_size, opt.batch_max_length + 1).fill_(0).to(device)
            # length_for_pred = torch.IntTensor([opt.batch_max_length] * opt.batch_size).to(device)

            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(fake_img, text_z_c, is_train=False)
                preds_size = torch.IntTensor([preds_recon.size(1)] *
                                             opt.batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(
                    1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_z_c,
                                       preds_size, length_z_c)
            else:
                print("Not implemented error")
                sys.exit()

        gen_enc_cost = disGenCost + opt.ocrWeight * ocrCost
        loss_avg_gen.add(disGenCost)
        loss_avg_ocr_unsup.add(opt.ocrWeight * ocrCost)

        genModel.zero_grad()
        cEncoder.zero_grad()

        gen_enc_cost = disGenCost + opt.ocrWeight * ocrCost
        grad_fake_OCR = torch.autograd.grad(ocrCost,
                                            fake_img,
                                            retain_graph=True)[0]
        loss_grad_fake_OCR = 10**6 * torch.mean(grad_fake_OCR**2)
        grad_fake_adv = torch.autograd.grad(disGenCost,
                                            fake_img,
                                            retain_graph=True)[0]
        loss_grad_fake_adv = 10**6 * torch.mean(grad_fake_adv**2)
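        # The gradient magnitudes of the OCR loss and the adversarial loss w.r.t. the fake
        # image are probed here; with opt.grad_balance the OCR term is rescaled below so the
        # standard deviation of its gradient matches the adversarial one (a ScrabbleGAN-style
        # gradient balancing).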

        if opt.grad_balance:
            gen_enc_cost.backward(retain_graph=True)
            grad_fake_OCR = torch.autograd.grad(ocrCost,
                                                fake_img,
                                                create_graph=True,
                                                retain_graph=True)[0]
            grad_fake_adv = torch.autograd.grad(disGenCost,
                                                fake_img,
                                                create_graph=True,
                                                retain_graph=True)[0]
            a = opt.ocrWeight * torch.div(torch.std(grad_fake_adv),
                                          epsilon + torch.std(grad_fake_OCR))
            if a is None:
                print(ocrCost, disGenCost, torch.std(grad_fake_adv),
                      torch.std(grad_fake_OCR))
            if a > 1000 or a < 0.0001:
                print(a)

            ocrCost = a.detach() * ocrCost
            gen_enc_cost = disGenCost + ocrCost
            gen_enc_cost.backward(retain_graph=True)
            grad_fake_OCR = torch.autograd.grad(ocrCost,
                                                fake_img,
                                                create_graph=False,
                                                retain_graph=True)[0]
            grad_fake_adv = torch.autograd.grad(disGenCost,
                                                fake_img,
                                                create_graph=False,
                                                retain_graph=True)[0]
            loss_grad_fake_OCR = 10**6 * torch.mean(grad_fake_OCR**2)
            loss_grad_fake_adv = 10**6 * torch.mean(grad_fake_adv**2)
            with torch.no_grad():
                gen_enc_cost.backward()
        else:
            gen_enc_cost.backward()

        loss_avg_gen.add(disGenCost)
        loss_avg_ocr_unsup.add(opt.ocrWeight * ocrCost)

        optimizer.step()

        g_regularize = cntr % opt.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, opt.batch_size // opt.path_batch_shrink)
            # image_input_tensors, _, labels, _ = iter(train_loader).next()
            labels_z_c = next(iter(text_loader))

            # image_input_tensors = image_input_tensors.to(device)
            # gt_image_tensors = image_input_tensors[:path_batch_size]

            # labels_gt = labels[:path_batch_size]

            text_z_c, length_z_c = converter.encode(
                labels_z_c[:path_batch_size],
                batch_max_length=opt.batch_max_length)
            # text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length)

            z_c_code = cEncoder(text_z_c)
            style = mixing_noise(z_c_code, path_batch_size, opt.latent,
                                 opt.mixing, device)

            if opt.zAlone:
                #to validate orig style gan results
                newstyle = []
                newstyle.append(style[0][:, :opt.latent])
                if len(style) > 1:
                    newstyle.append(style[1][:, :opt.latent])
                style = newstyle

            fake_img, grad = genModel(style,
                                      return_latents=True,
                                      g_path_regularize=True,
                                      mean_path_length=mean_path_length)

            decay = 0.01
            path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

            mean_path_length_orig = mean_path_length + decay * (
                path_lengths.mean() - mean_path_length)
            path_loss = (path_lengths - mean_path_length_orig).pow(2).mean()
            mean_path_length = mean_path_length_orig.detach().item()
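            # Path-length regularisation (StyleGAN2): path_lengths measures how much the
            # output image changes per unit change in latent space; the penalty pulls it
            # towards a running mean (mean_path_length) updated with decay 0.01.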

            genModel.zero_grad()
            cEncoder.zero_grad()
            weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss

            if opt.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            optimizer.step()

            # mean_path_length_avg = (
            #     reduce_sum(mean_path_length).item() / get_world_size()
            # )
            # the reduce_sum/get_world_size averaging above is commented out for this
            # non-distributed (single-process, multi-GPU DataParallel) setting
            mean_path_length_avg = mean_path_length

        accumulate(g_ema, genModel, accum)

        log_avg_path_loss_val.add(path_loss)
        log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg))
        log_ada_aug_p.add(torch.tensor(ada_aug_p))

        if get_rank() == 0:
            if wandb and opt.wandb:
                # log quantities that are available in this non-distributed loop
                wandb.log({
                    "Generator": disGenCost.item(),
                    "Discriminator": disCost.item(),
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_loss.item(),
                    "Path Length Regularization": path_loss.item(),
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_pred.mean().item(),
                    "Fake Score": fake_pred.mean().item(),
                    "Path Length": path_lengths.mean().item(),
                })

        # validation part
        if (
                iteration + 1
        ) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'

            #generate paired content with similar style
            labels_z_c_1 = next(iter(text_loader))
            labels_z_c_2 = next(iter(text_loader))

            text_z_c_1, length_z_c_1 = converter.encode(
                labels_z_c_1, batch_max_length=opt.batch_max_length)
            text_z_c_2, length_z_c_2 = converter.encode(
                labels_z_c_2, batch_max_length=opt.batch_max_length)

            z_c_code_1 = cEncoder(text_z_c_1)
            z_c_code_2 = cEncoder(text_z_c_2)

            style_c1_s1 = mixing_noise(z_c_code_1, opt.batch_size, opt.latent,
                                       opt.mixing, device)
            style_c2_s1 = []
            style_c2_s1.append(
                torch.cat((style_c1_s1[0][:, :opt.latent], z_c_code_2), dim=1))
            if len(style_c1_s1) > 1:
                style_c2_s1.append(
                    torch.cat((style_c1_s1[1][:, :opt.latent], z_c_code_2),
                              dim=1))

            style_c1_s2 = mixing_noise(z_c_code_1, opt.batch_size, opt.latent,
                                       opt.mixing, device)

            if opt.zAlone:
                #to validate orig style gan results
                newstyle = []
                newstyle.append(style_c1_s1[0][:, :opt.latent])
                if len(style_c1_s1) > 1:
                    newstyle.append(style_c1_s1[1][:, :opt.latent])
                style_c1_s1 = newstyle
                style_c2_s1 = newstyle
                style_c1_s2 = newstyle

            fake_img_c1_s1, _ = g_ema(style_c1_s1,
                                      input_is_latent=opt.input_latent)
            fake_img_c2_s1, _ = g_ema(style_c2_s1,
                                      input_is_latent=opt.input_latent)
            fake_img_c1_s2, _ = g_ema(style_c1_s2,
                                      input_is_latent=opt.input_latent)

            if not opt.zAlone:
                #Run OCR prediction
                if 'CTC' in opt.Prediction:
                    preds = ocrModel(fake_img_c1_s1,
                                     text_z_c_1,
                                     is_train=False)
                    preds_size = torch.IntTensor([preds.size(1)] *
                                                 opt.batch_size)
                    _, preds_index = preds.max(2)
                    preds_str_fake_img_c1_s1 = converter.decode(
                        preds_index.data, preds_size.data)

                    preds = ocrModel(fake_img_c2_s1,
                                     text_z_c_2,
                                     is_train=False)
                    preds_size = torch.IntTensor([preds.size(1)] *
                                                 opt.batch_size)
                    _, preds_index = preds.max(2)
                    preds_str_fake_img_c2_s1 = converter.decode(
                        preds_index.data, preds_size.data)

                    preds = ocrModel(fake_img_c1_s2,
                                     text_z_c_1,
                                     is_train=False)
                    preds_size = torch.IntTensor([preds.size(1)] *
                                                 opt.batch_size)
                    _, preds_index = preds.max(2)
                    preds_str_fake_img_c1_s2 = converter.decode(
                        preds_index.data, preds_size.data)

                    preds = ocrModel(gt_image_tensors, text_gt, is_train=False)
                    preds_size = torch.IntTensor([preds.size(1)] *
                                                 gt_image_tensors.shape[0])
                    _, preds_index = preds.max(2)
                    preds_str_gt = converter.decode(preds_index.data,
                                                    preds_size.data)

                else:
                    print("Not implemented error")
                    sys.exit()
            else:
                preds_str_fake_img_c1_s1 = [':None:'] * fake_img_c1_s1.shape[0]
                preds_str_gt = [':None:'] * fake_img_c1_s1.shape[0]

            os.makedirs(os.path.join(opt.trainDir, str(iteration)),
                        exist_ok=True)
            for trImgCntr in range(opt.batch_size):
                try:
                    save_image(
                        tensor2im(fake_img_c1_s1[trImgCntr].detach()),
                        os.path.join(
                            opt.trainDir, str(iteration),
                            str(trImgCntr) + '_c1_s1_' +
                            labels_z_c_1[trImgCntr] + '_ocr:' +
                            preds_str_fake_img_c1_s1[trImgCntr] + '.png'))
                    if not opt.zAlone:
                        save_image(
                            tensor2im(fake_img_c2_s1[trImgCntr].detach()),
                            os.path.join(
                                opt.trainDir, str(iteration),
                                str(trImgCntr) + '_c2_s1_' +
                                labels_z_c_2[trImgCntr] + '_ocr:' +
                                preds_str_fake_img_c2_s1[trImgCntr] + '.png'))
                        save_image(
                            tensor2im(fake_img_c1_s2[trImgCntr].detach()),
                            os.path.join(
                                opt.trainDir, str(iteration),
                                str(trImgCntr) + '_c1_s2_' +
                                labels_z_c_1[trImgCntr] + '_ocr:' +
                                preds_str_fake_img_c1_s2[trImgCntr] + '.png'))
                        if trImgCntr < gt_image_tensors.shape[0]:
                            save_image(
                                tensor2im(
                                    gt_image_tensors[trImgCntr].detach()),
                                os.path.join(
                                    opt.trainDir, str(iteration),
                                    str(trImgCntr) + '_gt_act:' +
                                    labels_gt[trImgCntr] + '_ocr:' +
                                    preds_str_gt[trImgCntr] + '.png'))
                except:
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time
            # for log

            with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'),
                      'a') as log:

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}]  \
                    Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\
                    Train UnSup OCR loss: {loss_avg_ocr_unsup.val():0.5f}, Train Sup OCR loss: {loss_avg_ocr_sup.val():0.5f}, \
                    Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \
                    Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \
                    Elapsed_time: {elapsed_time:0.5f}'

                #plotting
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Dis-Loss'),
                              loss_avg_dis.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Gen-Loss'),
                              loss_avg_gen.val().item())
                lib.plot.plot(
                    os.path.join(opt.plotDir, 'Train-UnSup-OCR-Loss'),
                    loss_avg_ocr_unsup.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Sup-OCR-Loss'),
                              loss_avg_ocr_sup.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-r1_val'),
                              log_r1_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-path_loss_val'),
                              log_avg_path_loss_val.val().item())
                lib.plot.plot(
                    os.path.join(opt.plotDir, 'Train-mean_path_length_avg'),
                    log_avg_mean_path_length_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-ada_aug_p'),
                              log_ada_aug_p.val().item())

                print(loss_log)

                loss_avg_dis.reset()
                loss_avg_gen.reset()
                loss_avg_ocr_unsup.reset()
                loss_avg_ocr_sup.reset()
                log_r1_val.reset()
                log_avg_path_loss_val.reset()
                log_avg_mean_path_length_avg.reset()
                log_ada_aug_p.reset()

            lib.plot.flush()

        lib.plot.tick()

        # save model per 1e+4 iter.
        if (iteration) % 1e+4 == 0:
            torch.save(
                {
                    'cEncoder': cEncoder.state_dict(),
                    'genModel': genModel.state_dict(),
                    'g_ema': g_ema.state_dict(),
                    'ocrModel': ocrModel.state_dict(),
                    'disEncModel': disEncModel.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'ocr_optimizer': ocr_optimizer.state_dict(),
                    'dis_optimizer': dis_optimizer.state_dict()
                },
                os.path.join(opt.exp_dir, opt.exp_name,
                             'iter_' + str(iteration + 1) + '_synth.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
        cntr += 1


def train(opt):
    plotDir = os.path.join(opt.exp_dir, opt.exp_name, 'plots')
    if not os.path.exists(plotDir):
        os.makedirs(plotDir)

    lib.print_model_settings(locals().copy())
    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')

    log = open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)

    train_dataset, train_dataset_log = hierarchical_dataset(
        root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # shuffle training batches
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    log.write(train_dataset_log)
    print('-' * 80)

    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=False,  # set to True to check training progress with the validation function
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)

    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    ocrModel = ModelV1(opt)
    styleModel = StyleTensorEncoder(input_dim=opt.input_channel)
    genModel = AdaIN_Tensor_WordGenerator(opt)
    disModel = MsImageDisV2(opt)

    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
        ocrCriterion = torch.nn.L1Loss()
    else:
        if 'CTC' in opt.Prediction:
            ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
        else:
            ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
                device)  # ignore [GO] token = ignore index 0

    vggRecCriterion = torch.nn.L1Loss()
    vggModel = VGGPerceptualLossModel(models.vgg19(pretrained=True),
                                      vggRecCriterion)

    print('model input parameters', opt.imgH, opt.imgW, opt.input_channel,
          opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length)

    #  weight initialization
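    # Kaiming-normal init for weight tensors and zeros for biases; parameters
    # that kaiming_normal_ cannot handle (e.g. 1-D batch-norm weights) fall
    # through to the except branch and are filled with 1.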
    for currModel in [styleModel, genModel, disModel]:
        for name, param in currModel.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception as e:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue

    styleModel = torch.nn.DataParallel(styleModel).to(device)
    styleModel.train()

    genModel = torch.nn.DataParallel(genModel).to(device)
    genModel.train()

    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    vggModel = torch.nn.DataParallel(vggModel).to(device)
    vggModel.eval()

    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    ocrModel.module.Transformation.eval()
    ocrModel.module.FeatureExtraction.eval()
    ocrModel.module.AdaptiveAvgPool.eval()
    # ocrModel.module.SequenceModeling.eval()
    ocrModel.module.Prediction.eval()

    if opt.modelFolderFlag:
        if len(
                glob.glob(
                    os.path.join(opt.exp_dir, opt.exp_name,
                                 "iter_*_synth.pth"))) > 0:
            opt.saved_synth_model = glob.glob(
                os.path.join(opt.exp_dir, opt.exp_name,
                             "iter_*_synth.pth"))[-1]

    if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None':
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        checkpoint = torch.load(opt.saved_ocr_model)
        ocrModel.load_state_dict(checkpoint)

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)

        styleModel.load_state_dict(checkpoint['styleModel'])
        genModel.load_state_dict(checkpoint['genModel'])
        disModel.load_state_dict(checkpoint['disModel'])

    if opt.imgReconLoss == 'l1':
        recCriterion = torch.nn.L1Loss()
    elif opt.imgReconLoss == 'ssim':
        recCriterion = ssim
    elif opt.imgReconLoss == 'ms-ssim':
        recCriterion = msssim

    if opt.styleLoss == 'l1':
        styleRecCriterion = torch.nn.L1Loss()
    elif opt.styleLoss == 'triplet':
        styleRecCriterion = torch.nn.TripletMarginLoss(
            margin=opt.tripletMargin, p=1)
    #for validation; check only positive pairs
    styleTestRecCriterion = torch.nn.L1Loss()

    # loss averager
    loss_avg = Averager()
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    loss_avg_imgRecon = Averager()
    loss_avg_vgg_per = Averager()
    loss_avg_vgg_sty = Averager()
    loss_avg_ocr = Averager()

    ##---------------------------------------##
    # collect only the parameters that require gradients (style encoder + generator)
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, styleModel.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    for p in filter(lambda p: p.requires_grad, genModel.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable style and generator params num : ', sum(params_num))

    # setup optimizer
    if opt.optim == 'adam':
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, opt.beta2),
                               weight_decay=opt.weight_decay)
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps,
                                   weight_decay=opt.weight_decay)
    print("SynthOptimizer:")
    print(optimizer)

    #filter parameters for Dis training
    dis_filtered_parameters = []
    dis_params_num = []
    for p in filter(lambda p: p.requires_grad, disModel.parameters()):
        dis_filtered_parameters.append(p)
        dis_params_num.append(np.prod(p.size()))
    print('Dis Trainable params num : ', sum(dis_params_num))

    # setup optimizer
    if opt.optim == 'adam':
        dis_optimizer = optim.Adam(dis_filtered_parameters,
                                   lr=opt.lr,
                                   betas=(opt.beta1, opt.beta2),
                                   weight_decay=opt.weight_decay)
    else:
        dis_optimizer = optim.Adadelta(dis_filtered_parameters,
                                       lr=opt.lr,
                                       rho=opt.rho,
                                       eps=opt.eps,
                                       weight_decay=opt.weight_decay)
    print("DisOptimizer:")
    print(dis_optimizer)
    ##---------------------------------------##
    """ final options """
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'),
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        try:
            start_iter = int(
                opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except (IndexError, ValueError):
            pass  # keep start_iter = 0 if the checkpoint name carries no iteration number

    #get schedulers
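    # get_scheduler builds an LR scheduler according to opt.lr_policy; both
    # schedulers are stepped once per training iteration in the loop below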
    scheduler = get_scheduler(optimizer, opt)
    dis_scheduler = get_scheduler(dis_optimizer, opt)

    start_time = time.time()
    iteration = start_iter
    cntr = 0

    while True:
        # train part
        if opt.lr_policy != "None":
            scheduler.step()
            dis_scheduler.step()

        # Python 3: draw one batch with next(); note that building a fresh
        # iterator every step restarts the DataLoader, so a persistent
        # iterator would be more efficient.
        image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(
            iter(train_loader))

        cntr += 1

        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        text_1, length_1 = converter.encode(
            labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(
            labels_2, batch_max_length=opt.batch_max_length)

        #forward pass from style and word generator
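        # the style encoder maps the input word image to a style tensor and the
        # AdaIN word generator renders the second label (text_2) in that style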
        style = styleModel(image_input_tensors)

        images_recon_2 = genModel(style, text_2)

        #Domain discriminator: Dis update
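        # the discriminator scores channel-wise concatenations of
        # (candidate, style input): the fake pair uses the detached
        # reconstruction so no gradient reaches the generator in this step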
        disModel.zero_grad()
        disCost = opt.disWeight * (disModel.module.calc_dis_loss(
            torch.cat((images_recon_2.detach(), image_input_tensors), dim=1),
            torch.cat((image_gt_tensors, image_input_tensors), dim=1)))

        disCost.backward()
        dis_optimizer.step()
        loss_avg_dis.add(disCost)

        # [Style Encoder] + [Word Generator] update
        #Adversarial loss
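        # generator-side adversarial term; the reconstruction is NOT detached
        # here, so gradients flow back into the style encoder and generator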
        disGenCost = disModel.module.calc_gen_loss(
            torch.cat((images_recon_2, image_input_tensors), dim=1))

        #Input reconstruction loss
        recCost = recCriterion(images_recon_2, image_gt_tensors)

        #vgg loss
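        # perceptual and style losses computed on VGG-19 features between the
        # ground-truth style-content image and the reconstruction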
        vggPerCost, vggStyleCost = vggModel(image_gt_tensors, images_recon_2)

        #ocr loss
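        # text_for_pred is a dummy target filled with the [GO] token (index 0)
        # and length_for_pred fixes the maximum decode length; the recognizer
        # runs with is_train=False, i.e. it decodes without teacher forcing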
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                         1).fill_(0).to(device)
        length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                          batch_size).to(device)
        if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
            preds_recon = ocrModel(images_recon_2,
                                   text_for_pred,
                                   is_train=False,
                                   returnFeat=opt.contentLoss)
            preds_gt = ocrModel(image_gt_tensors,
                                text_for_pred,
                                is_train=False,
                                returnFeat=opt.contentLoss)
            ocrCost = ocrCriterion(preds_recon, preds_gt)
        else:
            if 'CTC' in opt.Prediction:
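                # nn.CTCLoss expects log-probabilities shaped (T, N, C), hence
                # log_softmax(2).permute(1, 0, 2) below; preds_size holds the
                # full prediction length for every sample in the batch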
                preds_recon = ocrModel(images_recon_2,
                                       text_for_pred,
                                       is_train=False)
                # tensors have no .deepcopy(); detach and clone instead
                preds_o = preds_recon.detach().clone()[:, :text_1.shape[1] - 1, :]
                preds_size = torch.IntTensor([preds_recon.size(1)] *
                                             batch_size)
                preds_recon = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon, text_2, preds_size,
                                       length_2)

                #predict ocr recognition on generated images
                preds_o_size = torch.IntTensor([preds_o.size(1)] * batch_size)
                _, preds_o_index = preds_o.max(2)
                labels_o_ocr = converter.decode(preds_o_index.data,
                                                preds_o_size.data)

                #predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors,
                                   text_for_pred,
                                   is_train=False)
                preds_s = preds_s[:, :text_1.shape[1] - 1, :]
                preds_s_size = torch.IntTensor([preds_s.size(1)] * batch_size)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index.data,
                                                preds_s_size.data)

                #predict ocr recognition on gt stylecontent images
                # use the GT style-content image here (as in the Attn branch below)
                preds_sc = ocrModel(image_gt_tensors,
                                    text_for_pred,
                                    is_train=False)
                preds_sc = preds_sc[:, :text_2.shape[1] - 1, :]
                preds_sc_size = torch.IntTensor([preds_sc.size(1)] *
                                                batch_size)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index.data,
                                                 preds_sc_size.data)

            else:
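                # attention branch: text_for_pred[:, :-1] matches the input
                # expected by Attention.forward, the target drops the [GO]
                # symbol, and decoded strings are pruned at the '[s]' end token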
                preds_recon = ocrModel(
                    images_recon_2, text_for_pred[:, :-1],
                    is_train=False)  # align with Attention.forward
                target_2 = text_2[:, 1:]  # without [GO] Symbol
                ocrCost = ocrCriterion(
                    preds_recon.view(-1, preds_recon.shape[-1]),
                    target_2.contiguous().view(-1))

                #predict ocr recognition on generated images
                _, preds_o_index = preds_recon.max(2)
                labels_o_ocr = converter.decode(preds_o_index, length_for_pred)
                for idx, pred in enumerate(labels_o_ocr):
                    pred_EOS = pred.find('[s]')
                    # prune after "end of sentence" token ([s])
                    labels_o_ocr[idx] = pred[:pred_EOS]

                #predict ocr recognition on gt style images
                preds_s = ocrModel(image_input_tensors,
                                   text_for_pred,
                                   is_train=False)
                _, preds_s_index = preds_s.max(2)
                labels_s_ocr = converter.decode(preds_s_index, length_for_pred)
                for idx, pred in enumerate(labels_s_ocr):
                    pred_EOS = pred.find('[s]')
                    # prune after "end of sentence" token ([s])
                    labels_s_ocr[idx] = pred[:pred_EOS]

                #predict ocr recognition on gt stylecontent images
                preds_sc = ocrModel(image_gt_tensors,
                                    text_for_pred,
                                    is_train=False)
                _, preds_sc_index = preds_sc.max(2)
                labels_sc_ocr = converter.decode(preds_sc_index,
                                                 length_for_pred)
                for idx, pred in enumerate(labels_sc_ocr):
                    pred_EOS = pred.find('[s]')
                    # prune after "end of sentence" token ([s])
                    labels_sc_ocr[idx] = pred[:pred_EOS]

        # total synthesis objective: weighted sum of reconstruction,
        # adversarial, VGG perceptual/style and OCR content terms
        cost = (opt.reconWeight * recCost + opt.disWeight * disGenCost +
                opt.vggPerWeight * vggPerCost + opt.vggStyWeight * vggStyleCost +
                opt.ocrWeight * ocrCost)

        styleModel.zero_grad()
        genModel.zero_grad()
        disModel.zero_grad()
        vggModel.zero_grad()
        ocrModel.zero_grad()

        cost.backward()
        optimizer.step()
        loss_avg.add(cost)

        #Individual losses
        loss_avg_gen.add(opt.disWeight * disGenCost)
        loss_avg_imgRecon.add(opt.reconWeight * recCost)
        loss_avg_vgg_per.add(opt.vggPerWeight * vggPerCost)
        loss_avg_vgg_sty.add(opt.vggStyWeight * vggStyleCost)
        loss_avg_ocr.add(opt.ocrWeight * ocrCost)

        # validation part
        # To see training progress, we also conduct validation when 'iteration == 0'
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:

            #Save training images
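            # dump the style input, the ground-truth style-content image and
            # the reconstruction for this batch; filenames carry the labels
            # and, when a character-level OCR loss is used, the OCR predictions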
            os.makedirs(os.path.join(opt.exp_dir, opt.exp_name, 'trainImages',
                                     str(iteration)),
                        exist_ok=True)
            for trImgCntr in range(batch_size):
                try:
                    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
                        save_image(
                            tensor2im(image_input_tensors[trImgCntr].detach()),
                            os.path.join(
                                opt.exp_dir, opt.exp_name, 'trainImages',
                                str(iteration),
                                str(trImgCntr) + '_sInput_' +
                                labels_1[trImgCntr] + '.png'))
                        save_image(
                            tensor2im(image_gt_tensors[trImgCntr].detach()),
                            os.path.join(
                                opt.exp_dir, opt.exp_name, 'trainImages',
                                str(iteration),
                                str(trImgCntr) + '_csGT_' +
                                labels_2[trImgCntr] + '.png'))
                        save_image(
                            tensor2im(images_recon_2[trImgCntr].detach()),
                            os.path.join(
                                opt.exp_dir, opt.exp_name, 'trainImages',
                                str(iteration),
                                str(trImgCntr) + '_csRecon_' +
                                labels_2[trImgCntr] + '.png'))
                    else:
                        save_image(
                            tensor2im(image_input_tensors[trImgCntr].detach()),
                            os.path.join(
                                opt.exp_dir, opt.exp_name, 'trainImages',
                                str(iteration),
                                str(trImgCntr) + '_sInput_' +
                                labels_1[trImgCntr] + '_' +
                                labels_s_ocr[trImgCntr] + '.png'))
                        save_image(
                            tensor2im(image_gt_tensors[trImgCntr].detach()),
                            os.path.join(
                                opt.exp_dir, opt.exp_name, 'trainImages',
                                str(iteration),
                                str(trImgCntr) + '_csGT_' +
                                labels_2[trImgCntr] + '_' +
                                labels_sc_ocr[trImgCntr] + '.png'))
                        save_image(
                            tensor2im(images_recon_2[trImgCntr].detach()),
                            os.path.join(
                                opt.exp_dir, opt.exp_name, 'trainImages',
                                str(iteration),
                                str(trImgCntr) + '_csRecon_' +
                                labels_2[trImgCntr] + '_' +
                                labels_o_ocr[trImgCntr] + '.png'))
                except Exception as e:
                    print(f'Warning while saving training image: {e}')

            elapsed_time = time.time() - start_time
            # for log

            with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'),
                      'a') as log:
                styleModel.eval()
                genModel.eval()
                disModel.eval()

                with torch.no_grad():
                    valid_loss, infer_time, length_of_data = validation_synth_v4(
                        iteration, styleModel, genModel, vggModel, ocrModel,
                        disModel, recCriterion, ocrCriterion, valid_loader,
                        converter, opt)

                styleModel.train()
                genModel.train()
                disModel.train()

                # training loss and validation loss
                loss_log = (
                    f'[{iteration+1}/{opt.num_iter}] Train Synth loss: {loss_avg.val():0.5f}, '
                    f'Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f}, '
                    f'Train ImgRecon loss: {loss_avg_imgRecon.val():0.5f}, Train VGG-Per loss: {loss_avg_vgg_per.val():0.5f}, '
                    f'Train VGG-Sty loss: {loss_avg_vgg_sty.val():0.5f}, Train OCR loss: {loss_avg_ocr.val():0.5f}, '
                    f'Valid Synth loss: {valid_loss[0]:0.5f}, '
                    f'Valid Dis loss: {valid_loss[1]:0.5f}, Valid Gen loss: {valid_loss[2]:0.5f}, '
                    f'Valid ImgRecon loss: {valid_loss[3]:0.5f}, Valid VGG-Per loss: {valid_loss[4]:0.5f}, '
                    f'Valid VGG-Sty loss: {valid_loss[5]:0.5f}, Valid OCR loss: {valid_loss[6]:0.5f}, '
                    f'Elapsed_time: {elapsed_time:0.5f}')

                #plotting
                lib.plot.plot(os.path.join(plotDir, 'Train-Synth-Loss'),
                              loss_avg.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-Dis-Loss'),
                              loss_avg_dis.val().item())

                lib.plot.plot(os.path.join(plotDir, 'Train-Gen-Loss'),
                              loss_avg_gen.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-ImgRecon1-Loss'),
                              loss_avg_imgRecon.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-VGG-Per-Loss'),
                              loss_avg_vgg_per.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-VGG-Sty-Loss'),
                              loss_avg_vgg_sty.val().item())
                lib.plot.plot(os.path.join(plotDir, 'Train-OCR-Loss'),
                              loss_avg_ocr.val().item())

                lib.plot.plot(os.path.join(plotDir, 'Valid-Synth-Loss'),
                              valid_loss[0].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-Dis-Loss'),
                              valid_loss[1].item())

                lib.plot.plot(os.path.join(plotDir, 'Valid-Gen-Loss'),
                              valid_loss[2].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-ImgRecon1-Loss'),
                              valid_loss[3].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-VGG-Per-Loss'),
                              valid_loss[4].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-VGG-Sty-Loss'),
                              valid_loss[5].item())
                lib.plot.plot(os.path.join(plotDir, 'Valid-OCR-Loss'),
                              valid_loss[6].item())

                print(loss_log)

                loss_avg.reset()
                loss_avg_dis.reset()

                loss_avg_gen.reset()
                loss_avg_imgRecon.reset()
                loss_avg_vgg_per.reset()
                loss_avg_vgg_sty.reset()
                loss_avg_ocr.reset()

            lib.plot.flush()

        lib.plot.tick()

        # save model every 1e+4 iterations
        if iteration % 10000 == 0:
            torch.save(
                {
                    'styleModel': styleModel.state_dict(),
                    'genModel': genModel.state_dict(),
                    'disModel': disModel.state_dict()
                },
                os.path.join(opt.exp_dir, opt.exp_name,
                             'iter_' + str(iteration + 1) + '_synth.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1