Esempio n. 1
0
    def validate(self, epoch, val_loader):
        """Run patch-based validation, save predictions and log metrics.

        Every sample from ``val_loader`` is split into patches with
        ``self.splitcomb.split``; patches are queued, forwarded through
        ``self.warp`` in batches of ``self.args["train"]["batch_size"]``, and the
        resulting logits are recombined with ``self.splitcomb.combine``. Only the
        top-1 connected component is kept (``choose_top1_connected_component``),
        the functions in ``self.emlist`` (if any) are evaluated against the full
        target, and the prediction is written to
        ``<self.testdir>/<name>_pred.nii.gz`` and
        ``<self.save_dir>/val_out/<name>.npy``.

        When ``self.args["val"]`` is falsy (i.e. called during training) the
        method returns early unless ``epoch`` is a multiple of
        ``self.args["output"]["save_frequency"]`` (epoch 0 always runs) and
        ``epoch`` is 0 or >= 15; in that mode the network is also saved through
        ``self.ioer.save_file`` before the summary is logged.

        Args:
            epoch: Current epoch; used for scheduling and the summary file name
                ``<self.save_dir>/<epoch>_val.txt``.
            val_loader: Iterable of ``(data, target, name[, sequence_idx])``
                samples, consumed in order. NOTE(review): ``squeeze(0)`` below
                assumes batch size 1 — confirm against the loader.
        """
        # While training, validate only every `save_frequency` epochs and skip
        # epochs 1..14 entirely (epoch 0 always runs).
        if not self.args["val"]:
            if epoch % self.args["output"]['save_frequency'] != 0 and epoch > 0:
                return
            if 0 < epoch < 15:
                return
        self.net.eval()

        loss_avg = Averager()
        lls_avg = Averager()
        em_avg = Averager()
        bs = self.args["train"]["batch_size"]
        # Patch-level queues are dequeued in batches of `bs`; the prediction and
        # full-target queues regroup patches into whole samples again.
        data_filler = BatchFiller(bs)
        target_filler = BatchFiller(bs)
        pred_filler = BatchFiller()
        full_target_filler = BatchFiller()

        val_results = []
        # Initialised here so the summary log below cannot raise NameError when
        # no sample was reassembled (e.g. empty loader and emlist is None).
        em_list = []

        with torch.no_grad():
            pred_idx = 0
            use_cuda = torch.cuda.is_available()
            total_sample = len(val_loader)
            # The loader is consumed in order (not shuffled); patches are fed to
            # the network in batches of `bs`.
            for sample_idx, sample in enumerate(val_loader):

                data, target, name = sample[0], sample[1], sample[2]
                # An optional 4th element carries the sequence index.
                sequence_idx = sample[3] if len(sample) > 3 else 0

                data = data.squeeze(0)
                target = target.squeeze(0)
                data_pieces, split_position = self.splitcomb.split(data)
                target_pieces, split_position = self.splitcomb.split(target)
                num_pieces = data_pieces.shape[0]

                data_filler.enqueue(
                    sample=list(data_pieces),
                    name=[name[0] for _ in range(num_pieces)],
                    shape=[split_position for _ in range(num_pieces)],
                    sequence_idx=[sequence_idx for _ in range(num_pieces)],
                )
                target_filler.enqueue(
                    sample=list(target_pieces),
                    name=[name[0] for _ in range(num_pieces)],
                    shape=[split_position for _ in range(num_pieces)],
                    sequence_idx=[sequence_idx for _ in range(num_pieces)],
                )

                full_target_filler.enqueue(sample=[target],
                                           name=name,
                                           shape=[None],
                                           sequence_idx=[sequence_idx])

                if sample_idx + 1 == total_sample:
                    # Last sample: pad the patch queues so the final partial
                    # batch can be dequeued, and add a matching padding entry to
                    # the full-target queue.
                    pad_num = max(bs - len(data_filler.sample_queue) % bs, 1)
                    data_filler.enqueue(
                        sample=[
                            np.zeros_like(data_pieces[0, :])
                            for _ in range(pad_num)
                        ],
                        name=["padding" for _ in range(pad_num)],
                        shape=[None for _ in range(pad_num)],
                        sequence_idx=[sequence_idx for _ in range(pad_num)],
                    )
                    target_filler.enqueue(
                        sample=[(np.zeros_like(target_pieces[0, :]))
                                for _ in range(pad_num)],
                        name=["padding" for _ in range(pad_num)],
                        shape=[None for _ in range(pad_num)],
                        sequence_idx=[sequence_idx for _ in range(pad_num)],
                    )
                    full_target_filler.enqueue(
                        sample=[np.zeros_like(target)],
                        name=["padding"],
                        shape=[None],
                        sequence_idx=[sequence_idx],
                    )

                while data_filler.isFull(mode="batch"):

                    data_batch, name, shape, sequence_idx = data_filler.dequeue(
                        mode="batch")
                    target_batch, name, shape, sequence_idx = target_filler.dequeue(
                        mode="batch")

                    # Always build tensors; only device placement depends on
                    # CUDA (previously the CPU path passed raw lists to warp).
                    data_batch = torch.from_numpy(np.stack(data_batch, axis=0))
                    target_batch = torch.from_numpy(
                        np.stack(target_batch, axis=0))
                    if use_cuda:
                        data_batch = data_batch.cuda()
                        target_batch = target_batch.cuda()

                    total_loss, loss_list, logits = self.warp(
                        data_batch, target_batch, True, sequence_idx)

                    pred_filler.enqueue(sample=list(logits),
                                        name=name,
                                        shape=shape,
                                        sequence_idx=sequence_idx)
                    loss_avg.update(total_loss.mean().detach().cpu().numpy())
                    loss_list = tuple([l.cpu().numpy() for l in loss_list])
                    lls_avg.update(loss_list)

                while pred_filler.isFull(mode="sample"):

                    pred_full, _, shape, sequence_idx = pred_filler.dequeue(
                        mode="sample")
                    target_full, name, _, sequence_idx = full_target_filler.dequeue(
                        mode="sample")
                    pred_full = self.splitcomb.combine(pred_full, shape[0])
                    # Keep only the top-1 connected component of the prediction.
                    pred_full = choose_top1_connected_component(
                        model_pred=pred_full, choose_top1=self.choose_top1)

                    em_list = []
                    if self.emlist is not None:
                        for em_fun in self.emlist:
                            em_list.extend(em_fun(pred_full, target_full[0]))
                        em_list = tuple(
                            [l.cpu().squeeze().numpy() for l in em_list])
                        em_avg.update(em_list)

                    # Same uint8 array is written both as NIfTI and as .npy.
                    pred_np = pred_full.cpu().squeeze(0).numpy().astype(
                        np.uint8)

                    curr_case_nii_path = os.path.join(self.testdir,
                                                      name[0]) + "_pred.nii.gz"
                    os.makedirs(os.path.dirname(curr_case_nii_path),
                                exist_ok=True)
                    self.writer.SetFileName(curr_case_nii_path)
                    self.writer.Execute(sitk.GetImageFromArray(pred_np))

                    curr_case_npy_path = os.path.join(self.save_dir, 'val_out',
                                                      '{}.npy'.format(name[0]))
                    os.makedirs(os.path.dirname(curr_case_npy_path),
                                exist_ok=True)
                    np.save(curr_case_npy_path, pred_np)

                    info = "Finish validation %d out of %d, name %s, " % (
                        pred_idx + 1,
                        len(val_loader),
                        name[0],
                    )
                    pred_idx += 1
                    for lid, l in enumerate(em_list):
                        info += "em %d: %.4f, " % (lid, l)
                    print(info)
                    val_results.append(info)

        if not self.args["val"]:
            if epoch % self.args["output"]["save_frequency"] == 0:
                self.ioer.save_file(self.net, epoch, self.args, 0)
            else:
                return

        if self.emlist is not None:
            em_list = em_avg.val()
        self.__writeLossLog(
            "Val",
            epoch,
            meanloss=loss_avg.val(),
            loss_list=lls_avg.val(),
            em_list=em_list,
        )

        with open(os.path.join(self.save_dir, '{}_val.txt'.format(epoch)),
                  'a') as f_out:
            f_out.write('\n'.join(val_results) + '\n\n')
Esempio n. 2
0
def _latest_synth_checkpoint(exp_path):
    """Return the ``iter_<N>_synth.pth`` file in ``exp_path`` with the largest N.

    ``glob`` returns files in arbitrary filesystem order, so taking ``[-1]`` of
    its result does not reliably select the newest checkpoint.

    Returns:
        The path of the newest checkpoint, or ``None`` if there is none.
    """
    candidates = glob.glob(os.path.join(exp_path, "iter_*_synth.pth"))
    if not candidates:
        return None

    def _iter_num(path):
        try:
            return int(os.path.basename(path).split('_')[-2])
        except (IndexError, ValueError):
            return -1

    return max(candidates, key=_iter_num)


def _cycle_loader(loader):
    """Yield batches from ``loader`` forever, starting a new pass when exhausted.

    Returns (so ``next()`` raises ``StopIteration``) if a pass yields nothing,
    instead of spinning forever on an empty loader.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            return


def _build_optimizer(params, opt):
    """Create Adam when ``opt.optim == 'adam'``, otherwise Adadelta, over ``params``."""
    if opt.optim == 'adam':
        return optim.Adam(params,
                          lr=opt.lr,
                          betas=(opt.beta1, opt.beta2),
                          weight_decay=opt.weight_decay)
    return optim.Adadelta(params,
                          lr=opt.lr,
                          rho=opt.rho,
                          eps=opt.eps,
                          weight_decay=opt.weight_decay)


def train(opt):
    """Train the style encoder and word generator against a discriminator.

    Builds training/validation loaders from ``opt.train_data`` and
    ``opt.valid_data``, creates the style encoder, word generator,
    discriminator and a VGG19 perceptual-loss model, optionally resumes from an
    ``iter_<N>_synth.pth`` checkpoint, and then loops:

    * discriminator update on (generated, input) vs (ground truth, input);
    * style encoder + generator update with
      ``reconWeight*recon + disWeight*adversarial + vggPerWeight*perceptual
      + vggStyWeight*style``;
    * every ``opt.valInterval`` iterations (and at iteration 0): dump sample
      images, run ``validation_synth_v3``, append a line to
      ``log_train.txt``, print it and update the plots;
    * every 10000 iterations: save a checkpoint.

    The process is terminated with ``sys.exit()`` after ``opt.num_iter``
    iterations.

    Args:
        opt: Options namespace. Mutated in place: ``select_data``,
            ``batch_ratio``, ``num_class``, possibly ``input_channel`` and
            ``saved_synth_model``.

    Raises:
        ValueError: If ``opt.imgReconLoss`` is not 'l1', 'ssim' or 'ms-ssim'.
    """
    plotDir = os.path.join(opt.exp_dir, opt.exp_name, 'plots')
    os.makedirs(plotDir, exist_ok=True)

    lib.print_model_settings(locals().copy())
    exp_path = os.path.join(opt.exp_dir, opt.exp_name)

    # ---------------- dataset preparation ----------------
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')

    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)

    with open(os.path.join(exp_path, 'log_dataset.txt'), 'a') as log:
        train_dataset, train_dataset_log = hierarchical_dataset(
            root=opt.train_data, opt=opt)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(train_dataset_log)
        print('-' * 80)

        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=False,
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    converter = CTCLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    styleModel = StyleTensorEncoder(input_dim=opt.input_channel)
    genModel = AdaIN_Tensor_WordGenerator(opt)
    disModel = MsImageDisV2(opt)

    vggRecCriterion = torch.nn.L1Loss()
    vggModel = VGGPerceptualLossModel(models.vgg19(pretrained=True),
                                      vggRecCriterion)

    print('model input parameters', opt.imgH, opt.imgW, opt.input_channel,
          opt.output_channel, opt.hidden_size, opt.num_class,
          opt.batch_max_length)

    # Weight initialization: zero biases, Kaiming-normal weights; tensors that
    # Kaiming cannot handle (e.g. 1-D batchnorm weights) are filled with 1.
    for currModel in [styleModel, genModel, disModel]:
        for name, param in currModel.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue

    styleModel = torch.nn.DataParallel(styleModel).to(device)
    styleModel.train()

    genModel = torch.nn.DataParallel(genModel).to(device)
    genModel.train()

    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    vggModel = torch.nn.DataParallel(vggModel).to(device)
    vggModel.eval()

    if opt.modelFolderFlag:
        latest = _latest_synth_checkpoint(exp_path)
        if latest is not None:
            opt.saved_synth_model = latest

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)

        styleModel.load_state_dict(checkpoint['styleModel'])
        genModel.load_state_dict(checkpoint['genModel'])
        disModel.load_state_dict(checkpoint['disModel'])

    if opt.imgReconLoss == 'l1':
        recCriterion = torch.nn.L1Loss()
    elif opt.imgReconLoss == 'ssim':
        recCriterion = ssim
    elif opt.imgReconLoss == 'ms-ssim':
        recCriterion = msssim
    else:
        # Fail fast instead of a NameError on the first training step.
        raise ValueError(f'Unknown imgReconLoss: {opt.imgReconLoss!r}')

    # loss averagers
    loss_avg = Averager()
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    loss_avg_imgRecon = Averager()
    loss_avg_vgg_per = Averager()
    loss_avg_vgg_sty = Averager()

    # Style encoder + generator are optimised jointly; only trainable params.
    filtered_parameters = [
        p for model in (styleModel, genModel) for p in model.parameters()
        if p.requires_grad
    ]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable style and generator params num : ', sum(params_num))

    optimizer = _build_optimizer(filtered_parameters, opt)
    print("SynthOptimizer:")
    print(optimizer)

    # Discriminator has its own optimizer.
    dis_filtered_parameters = [
        p for p in disModel.parameters() if p.requires_grad
    ]
    dis_params_num = [np.prod(p.size()) for p in dis_filtered_parameters]
    print('Dis Trainable params num : ', sum(dis_params_num))

    dis_optimizer = _build_optimizer(dis_filtered_parameters, opt)
    print("DisOptimizer:")
    print(dis_optimizer)

    # ---------------- final options ----------------
    with open(os.path.join(exp_path, 'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---------------- start training ----------------
    start_iter = 0

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        # Checkpoints are named iter_<N>_synth.pth; resume from N if parsable.
        try:
            start_iter = int(
                opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except (IndexError, ValueError):
            pass

    scheduler = get_scheduler(optimizer, opt)
    dis_scheduler = get_scheduler(dis_optimizer, opt)

    start_time = time.time()
    iteration = start_iter
    # Persistent iterator: re-creating iter(train_loader) every step respawns
    # the worker processes each time, and DataLoader iterators no longer
    # provide a `.next()` method.
    train_batches = _cycle_loader(train_loader)

    while True:
        image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(
            train_batches)

        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)
        text_2, length_2 = converter.encode(
            labels_2, batch_max_length=opt.batch_max_length)

        # Forward pass: style from the input image, then render text_2 in it.
        style = styleModel(image_input_tensors)
        images_recon_2 = genModel(style, text_2)

        # Discriminator update; generated images are detached so only the
        # discriminator receives gradients here.
        disModel.zero_grad()
        disCost = opt.disWeight * (disModel.module.calc_dis_loss(
            torch.cat((images_recon_2.detach(), image_input_tensors), dim=1),
            torch.cat((image_gt_tensors, image_input_tensors), dim=1)))

        disCost.backward()
        dis_optimizer.step()
        loss_avg_dis.add(disCost)

        # Style encoder + word generator update.
        disGenCost = disModel.module.calc_gen_loss(
            torch.cat((images_recon_2, image_input_tensors), dim=1))
        recCost = recCriterion(images_recon_2, image_gt_tensors)
        vggPerCost, vggStyleCost = vggModel(image_gt_tensors, images_recon_2)

        cost = (opt.reconWeight * recCost + opt.disWeight * disGenCost +
                opt.vggPerWeight * vggPerCost +
                opt.vggStyWeight * vggStyleCost)

        # Clear grads on every model (the adversarial term also reaches the
        # discriminator) so only this step's gradients are applied.
        styleModel.zero_grad()
        genModel.zero_grad()
        disModel.zero_grad()
        vggModel.zero_grad()

        cost.backward()
        optimizer.step()
        loss_avg.add(cost)

        # Advance LR schedules after the optimizer steps (PyTorch >= 1.1 order).
        if opt.lr_policy != "None":
            scheduler.step()
            dis_scheduler.step()

        # Individual (weighted) losses
        loss_avg_gen.add(opt.disWeight * disGenCost)
        loss_avg_imgRecon.add(opt.reconWeight * recCost)
        loss_avg_vgg_per.add(opt.vggPerWeight * vggPerCost)
        loss_avg_vgg_sty.add(opt.vggStyWeight * vggStyleCost)

        # Validate every `valInterval` iterations and also at iteration 0 to
        # see training progress early.
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:

            # Save training images (input, ground truth, reconstruction).
            train_img_dir = os.path.join(exp_path, 'trainImages',
                                         str(iteration))
            os.makedirs(train_img_dir, exist_ok=True)
            for trImgCntr in range(batch_size):
                try:
                    save_image(
                        tensor2im(image_input_tensors[trImgCntr].detach()),
                        os.path.join(
                            train_img_dir,
                            str(trImgCntr) + '_sInput_' + labels_1[trImgCntr] +
                            '.png'))
                    save_image(
                        tensor2im(image_gt_tensors[trImgCntr].detach()),
                        os.path.join(
                            train_img_dir,
                            str(trImgCntr) + '_csGT_' + labels_2[trImgCntr] +
                            '.png'))
                    save_image(
                        tensor2im(images_recon_2[trImgCntr].detach()),
                        os.path.join(
                            train_img_dir,
                            str(trImgCntr) + '_csRecon_' +
                            labels_2[trImgCntr] + '.png'))
                except Exception:
                    # Best effort: a bad label/path must not stop training.
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time

            with open(os.path.join(exp_path, 'log_train.txt'), 'a') as log:
                styleModel.eval()
                genModel.eval()
                disModel.eval()

                with torch.no_grad():
                    valid_loss, infer_time, length_of_data = validation_synth_v3(
                        iteration, styleModel, genModel, vggModel, disModel,
                        recCriterion, valid_loader, converter, opt)

                styleModel.train()
                genModel.train()
                disModel.train()

                # valid_loss order (matching the plots below): synth, dis, gen,
                # imgRecon, vgg-per, vgg-sty.
                loss_log = (
                    f'[{iteration+1}/{opt.num_iter}] '
                    f'Train Synth loss: {loss_avg.val():0.5f}, '
                    f'Train Dis loss: {loss_avg_dis.val():0.5f}, '
                    f'Train Gen loss: {loss_avg_gen.val():0.5f}, '
                    f'Train ImgRecon loss: {loss_avg_imgRecon.val():0.5f}, '
                    f'Train VGG-Per loss: {loss_avg_vgg_per.val():0.5f}, '
                    f'Train VGG-Sty loss: {loss_avg_vgg_sty.val():0.5f}, '
                    f'Valid Synth loss: {valid_loss[0]:0.5f}, '
                    f'Valid Dis loss: {valid_loss[1]:0.5f}, '
                    f'Elapsed_time: {elapsed_time:0.5f}')

                train_plots = (
                    ('Train-Synth-Loss', loss_avg),
                    ('Train-Dis-Loss', loss_avg_dis),
                    ('Train-Gen-Loss', loss_avg_gen),
                    ('Train-ImgRecon1-Loss', loss_avg_imgRecon),
                    ('Train-VGG-Per-Loss', loss_avg_vgg_per),
                    ('Train-VGG-Sty-Loss', loss_avg_vgg_sty),
                )
                for tag, averager in train_plots:
                    lib.plot.plot(os.path.join(plotDir, tag),
                                  averager.val().item())

                valid_tags = ('Valid-Synth-Loss', 'Valid-Dis-Loss',
                              'Valid-Gen-Loss', 'Valid-ImgRecon1-Loss',
                              'Valid-VGG-Per-Loss', 'Valid-VGG-Sty-Loss')
                for idx, tag in enumerate(valid_tags):
                    lib.plot.plot(os.path.join(plotDir, tag),
                                  valid_loss[idx].item())

                log.write(loss_log + '\n')
                print(loss_log)

                for _, averager in train_plots:
                    averager.reset()

            lib.plot.flush()

        lib.plot.tick()

        # Save a checkpoint every 1e4 iterations (named with iteration + 1).
        if iteration % 10000 == 0:
            torch.save(
                {
                    'styleModel': styleModel.state_dict(),
                    'genModel': genModel.state_dict(),
                    'disModel': disModel.state_dict()
                },
                os.path.join(exp_path,
                             'iter_' + str(iteration + 1) + '_synth.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
Esempio n. 3
0
def train():
    """Train the advancedEAST detector using the settings in ``cfg``.

    Loads LMDB train/validation sets, optionally initialises the model from a
    saved checkpoint, and trains with Adam and ``quad_loss`` for
    ``cfg.epoch_num`` epochs (counted from ``start_iter``, parsed from
    ``cfg.saved_model`` when set). Every ``cfg.valInterval`` epochs it runs
    validation and saves the model when the mean F1 score improves. Snapshots
    are saved every 100 epochs and after the last epoch. The process ends with
    ``sys.exit()``.
    """
    # ---------------- dataset preparation ----------------
    train_dataset_lmdb = LmdbDataset(cfg.lmdb_trainset_dir_name)
    val_dataset_lmdb = LmdbDataset(cfg.lmdb_valset_dir_name)

    train_loader = torch.utils.data.DataLoader(
        train_dataset_lmdb, batch_size=cfg.batch_size,
        collate_fn=data_collate,
        shuffle=True,
        num_workers=int(cfg.workers),
        pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(
        val_dataset_lmdb, batch_size=cfg.batch_size,
        collate_fn=data_collate,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(cfg.workers),
        pin_memory=True)

    # ---------------- model setup ----------------
    model = advancedEAST()
    # Sizes above 256 start from the best checkpoint of the next-smaller size;
    # size 256 resumes from its own checkpoint when one exists.
    if int(cfg.train_task_id[-3:]) != 256:
        id_num = cfg.train_task_id[-3:]
        idx_dic = {'384': 256, '512': 384, '640': 512, '736': 640}
        model.load_state_dict(torch.load('./saved_model/3T{}_best_loss.pth'.format(idx_dic[id_num])))
    elif os.path.exists('./saved_model/3T{}_best_loss.pth'.format(cfg.train_task_id)):
        model.load_state_dict(torch.load('./saved_model/3T{}_best_loss.pth'.format(cfg.train_task_id)))

    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.decay)
    loss_func = quad_loss

    train_Loss_list = []
    val_Loss_list = []

    # ---------------- start training ----------------
    start_iter = 0
    if cfg.saved_model != '':
        # Best effort: resume epoch number from "<...>_<N>.pth".
        try:
            start_iter = int(cfg.saved_model.split('_')[-1].split('.')[0])
            print('continue to train, start_iter: {}'.format(start_iter))
        except Exception as e:
            print(e)

    best_mF1_score = 0
    i = start_iter  # epoch counter (0-based)
    step_num = 0
    start_time = time.time()
    loss_avg = Averager()
    val_loss_avg = Averager()
    eval_p_r_f = eval_pre_rec_f1()

    while True:
        model.train()
        # Precision/recall/F1 on training batches is very expensive, so it is
        # only collected on epoch index 5 and every 10th epoch.
        collect_train_metrics = i == 5 or (i + 1) % 10 == 0
        for image_tensors, labels, gt_xy_list in train_loader:
            step_num += 1
            batch_x = image_tensors.to(device).float()
            batch_y = labels.to(device).float()  # float64 -> float32

            out = model(batch_x)
            loss = loss_func(batch_y, out)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_avg.add(loss)
            train_Loss_list.append(loss_avg.val())
            if collect_train_metrics:
                eval_p_r_f.add(out, gt_xy_list)

        # Save a snapshot every 100 epochs.
        if (i + 1) % 100 == 0:
            torch.save(model.state_dict(), './saved_models/{}/{}_iter_{}.pth'.format(cfg.train_task_id, cfg.train_task_id, step_num + 1))

        print('Epoch:[{}/{}] Training Loss: {:.3f}'.format(i + 1, cfg.epoch_num, train_Loss_list[-1].item()))
        loss_avg.reset()

        if collect_train_metrics:
            mPre, mRec, mF1_score = eval_p_r_f.val()
            print('Training meanPrecision:{:.2f}% meanRecall:{:.2f}% meanF1-score:{:.2f}%'.format(mPre, mRec, mF1_score))
            eval_p_r_f.reset()

        # ---------------- evaluation ----------------
        if (i + 1) % cfg.valInterval == 0:
            elapsed_time = time.time() - start_time
            print('Elapsed time:{}s'.format(round(elapsed_time)))
            model.eval()
            # Validation needs no autograd graph.
            with torch.no_grad():
                for image_tensors, labels, gt_xy_list in valid_loader:
                    # Cast like the training path so input dtype matches the model.
                    batch_x = image_tensors.to(device).float()
                    batch_y = labels.to(device).float()  # float64 -> float32

                    out = model(batch_x)
                    loss = loss_func(batch_y, out)

                    val_loss_avg.add(loss)
                    val_Loss_list.append(val_loss_avg.val())
                    eval_p_r_f.add(out, gt_xy_list)

            mPre, mRec, mF1_score = eval_p_r_f.val()
            print('validation meanPrecision:{:.2f}% meanRecall:{:.2f}% meanF1-score:{:.2f}%'.format(mPre, mRec, mF1_score))
            eval_p_r_f.reset()

            if mF1_score > best_mF1_score:  # keep the best model
                best_mF1_score = mF1_score
                torch.save(model.state_dict(), './saved_models/{}/{}_best_mF1_score_{:.3f}.pth'.format(cfg.train_task_id, cfg.train_task_id, mF1_score))
                torch.save(model.state_dict(), './saved_model/{}_best_mF1_score.pth'.format(cfg.train_task_id))

            print('Validation loss:{:.3f}'.format(val_loss_avg.val().item()))
            val_loss_avg.reset()

        # Stop after epoch `cfg.epoch_num` (i is 0-based). `>=` also terminates
        # when resuming from a start_iter already past epoch_num.
        if i + 1 >= cfg.epoch_num:
            torch.save(model.state_dict(), './saved_models/{}/{}_iter_{}.pth'.format(cfg.train_task_id, cfg.train_task_id, i + 1))
            print('End the training')
            break
        i += 1

    sys.exit()
Esempio n. 4
0
    def train(self, model, dataloader, train_loader, valid_loader):
        """Train ``model`` with cross-entropy and keep the best-accuracy weights.

        The model is wrapped in ``DataParallel`` and trained for ``self.epochs``
        epochs with Adam and a cosine-annealed learning rate (stepped once per
        epoch, matching ``T_max=self.epochs``). Each epoch runs a training pass
        (inputs ``text[:, :-1]``, targets ``text[:, 1:]``, class index 0
        ignored by the loss, gradient norm clipped to 5) and a validation pass.
        It then appends a JSON record to ``<save_folder>/<name>.log`` and
        saves ``<save_folder>/<name>.pth`` whenever validation accuracy
        improves.

        Args:
            model: Recognition model exposing a ``classes`` attribute.
            dataloader: Only used for ``dataloader.dataset.converter.decode``
                to show a sample prediction in the progress bar.
            train_loader: Yields ``(img, (text, length))`` training batches.
            valid_loader: Same batch format as ``train_loader``.

        Raises:
            SyntaxError: If the output folder already exists and
                ``self.args.delete`` is falsy (type kept for backward
                compatibility; FileExistsError would be more accurate).
        """
        save_folder = os.path.join(self.save_path, self.args.name)
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)
        elif not self.args.delete:
            raise SyntaxError(f'{save_folder} is exist.')

        classes = model.classes
        model = torch.nn.DataParallel(model).to(self.device)

        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(self.device)

        filtered_parameters = [p for p in model.parameters() if p.requires_grad]
        params_num = [np.prod(p.size()) for p in filtered_parameters]
        print('Trainable params num : ', sum(params_num))

        # optimizer & scheduler (T_max is in epochs -> step once per epoch)
        optimizer = optim.Adam(filtered_parameters,
                               lr=self.lr,
                               betas=(0.9, 0.999))
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         eta_min=1e-5,
                                                         T_max=self.epochs)

        best_acc = 0
        taken_time = time()
        for epoch in range(self.epochs):
            t_loss_avg = Averager()
            v_loss_avg = Averager()
            t_calc = ScoreCalc()
            v_calc = ScoreCalc()
            model.train()

            word_target = None
            word_preds = None

            with tqdm(train_loader, unit="batch") as tepoch:
                for batch, batch_sampler in enumerate(tepoch):
                    tepoch.set_description(
                        f"Epoch {epoch+1} / Batch {batch+1}")

                    img = batch_sampler[0].to(self.device)
                    text = batch_sampler[1][0].to(self.device)
                    length = batch_sampler[1][1]

                    # Teacher forcing: feed text[:, :-1], predict text[:, 1:].
                    # (The former ASTER / non-ASTER branches were identical.)
                    preds = model(img, text[:, :-1],
                                  max(length).cpu().numpy())

                    target = text[:, 1:]
                    t_cost = criterion(
                        preds.contiguous().view(-1, preds.shape[-1]),
                        target.contiguous().view(-1))

                    model.zero_grad()
                    t_cost.backward()

                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        5)  # gradient clipping with 5 (Default)

                    optimizer.step()

                    t_loss_avg.add(t_cost)
                    self.batch_size = len(text)
                    probs = F.softmax(preds, dim=2).view(self.batch_size, -1,
                                                         classes)
                    pred_max = torch.argmax(probs, 2)

                    t_calc.add(target, probs, length)
                    # Refresh the sample prediction shown in the bar every 300 batches.
                    if batch % 300 == 0:
                        word_target = dataloader.dataset.converter.decode(
                            target, length)[0]
                        word_preds = dataloader.dataset.converter.decode(
                            pred_max, length)[0]
                    tepoch.set_postfix(loss=t_loss_avg.val().item(),
                                       acc=t_calc.val().item(),
                                       preds=word_preds,
                                       target=word_target)

                    del batch_sampler, pred_max, img, text, length

            # One scheduler step per epoch, consistent with T_max=self.epochs.
            scheduler.step()

            model.eval()
            with tqdm(valid_loader, unit="batch") as vepoch:
                for batch, batch_sampler in enumerate(vepoch):
                    vepoch.set_description(
                        f"Epoch {epoch+1} / Batch {batch+1}")
                    with torch.no_grad():
                        img = batch_sampler[0].to(self.device)
                        text = batch_sampler[1][0].to(self.device)
                        length = batch_sampler[1][1].to(self.device)

                        preds = model(img, text[:, :-1],
                                      max(length).cpu().numpy())
                        target = text[:, 1:]
                        v_cost = criterion(
                            preds.contiguous().view(-1, preds.shape[-1]),
                            target.contiguous().view(-1))

                        v_loss_avg.add(v_cost)
                        batch_size = len(text)
                        probs = F.softmax(preds, dim=2).view(
                            batch_size, -1, classes)
                        pred_max = torch.argmax(probs, 2)

                        v_calc.add(target, probs, length)

                        vepoch.set_postfix(loss=v_loss_avg.val().item(),
                                           acc=v_calc.val().item())
                        del batch_sampler, v_cost, pred_max, img, text, length

            # NOTE(review): this nested directory is created but never written.
            os.makedirs(os.path.join(save_folder, self.args.name),
                        exist_ok=True)
            log = dict()
            log['epoch'] = epoch + 1
            log['t_loss'] = t_loss_avg.val().item()
            log['t_acc'] = t_calc.val().item()

            log['v_loss'] = v_loss_avg.val().item()
            log['v_acc'] = v_calc.val().item()
            log['time'] = time() - taken_time
            with open(os.path.join(save_folder, f'{self.args.name}.log'),
                      'a') as f:
                json.dump(log, f, indent=2)
                # Separate consecutive records instead of gluing "}{" together.
                f.write('\n')

            if best_acc < v_calc.val().item():
                best_acc = v_calc.val().item()
                torch.save(model.state_dict(),
                           os.path.join(save_folder, f'{self.args.name}.pth'))
def _checkpoint_iteration(path, default=None):
    """Parse the iteration number out of a checkpoint path, or return ``default``.

    Checkpoints written by ``train`` are named ``iter_<N>_vggFont.pth``; ``N`` is
    the second-to-last ``'_'``-separated field of the path.
    """
    try:
        return int(path.split('_')[-2].split('.')[0])
    except (ValueError, IndexError):
        return default


def train(opt):
    """Train the VGG-19 font classifier with periodic validation and checkpoints.

    Builds train/validation ``fontTextDataset`` loaders and creates a
    ``VGGFontModel`` on ``device``. In distributed mode the model is wrapped in
    ``DistributedDataParallel``. The function optionally resumes from a
    checkpoint, then runs an iteration-based loop until ``opt.num_iter``
    iterations and calls ``sys.exit()`` at the end.

    When ``opt.testFlag`` is set, no optimisation step is taken. Every
    iteration then only runs validation and prints actual vs. predicted font
    names.

    Args:
        opt: argparse-style namespace with the experiment options (data paths,
            ``batch_size``, ``optim``, ``lr``, ``valInterval``, ``num_iter``,
            ``exp_dir``/``exp_name``, ``distributed``, ...).

    Raises:
        ValueError: if ``opt.optim`` is neither ``"sgd"`` nor ``"adam"``.
    """
    lib.print_model_settings(locals().copy())

    AlignFontCollateObj = AlignFontCollate(imgH=opt.imgH,
                                           imgW=opt.imgW,
                                           keep_ratio_with_pad=opt.PAD)
    train_dataset = fontTextDataset(imgDir=opt.train_img_dir,
                                    annFile=opt.train_ann_file,
                                    transform=None,
                                    numClasses=opt.numClasses)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=False,  # must stay False: ordering is controlled by the sampler
        sampler=data_sampler(train_dataset,
                             shuffle=True,
                             distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignFontCollateObj,
        pin_memory=True,
        drop_last=False)
    # Number of distinct font labels present in the training annotations.
    numClasses = np.unique(train_dataset.fontIdx).size

    # The wrapped loader is consumed with next() in the training loop below.
    train_loader = sample_data(train_loader)
    print('-' * 80)
    numTrainSamples = len(train_dataset)

    # NOTE: validation images are also read from opt.train_img_dir; only the
    # annotation file differs. The font<->index maps come from the training
    # set so that label indices agree between the two splits.
    valid_dataset = fontTextDataset(imgDir=opt.train_img_dir,
                                    annFile=opt.val_ann_file,
                                    transform=None,
                                    F2Idx=train_dataset.F2Idx,
                                    Idx2F=train_dataset.Idx2F,
                                    numClasses=opt.numClasses)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        sampler=data_sampler(valid_dataset,
                             shuffle=False,
                             distributed=opt.distributed),
        num_workers=int(opt.workers),
        collate_fn=AlignFontCollateObj,
        pin_memory=True,
        drop_last=False)
    numTestSamples = len(valid_dataset)

    print('numClasses', numClasses)
    print('numTrainSamples', numTrainSamples)
    print('numTestSamples', numTestSamples)

    vggFontModel = VGGFontModel(models.vgg19(pretrained=opt.preTrained),
                                numClasses).to(device)
    # Re-initialise the classifier head: zero biases, Kaiming-normal weights.
    # Weights that kaiming_normal_ rejects (presumably 1-D batch-norm weights)
    # are filled with 1 instead.
    for name, param in vggFontModel.classifier.named_parameters():
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            print('Exception in weight init' + name)
            if 'weight' in name:
                param.data.fill_(1)

    if opt.optim == "sgd":
        print('SGD optimizer')
        optimizer = optim.SGD(vggFontModel.parameters(),
                              lr=opt.lr,
                              momentum=0.9)
    elif opt.optim == "adam":
        print('Adam optimizer')
        optimizer = optim.Adam(vggFontModel.parameters(), lr=opt.lr)
    else:
        # Fail early instead of hitting a NameError on `optimizer` below.
        raise ValueError(
            f'Unsupported optimizer {opt.optim!r}; expected "sgd" or "adam"')
    scheduler = get_scheduler(optimizer, opt)

    criterion = torch.nn.CrossEntropyLoss()

    if opt.modelFolderFlag:
        checkpoints = glob.glob(
            os.path.join(opt.exp_dir, opt.exp_name, "iter_*_vggFont.pth"))
        if checkpoints:
            # glob() returns paths in arbitrary order, so pick the checkpoint
            # with the highest iteration number explicitly.
            opt.saved_font_model = max(
                checkpoints, key=lambda p: _checkpoint_iteration(p, -1))

    resume = opt.saved_font_model != '' and opt.saved_font_model != 'None'

    # Load model / optimizer / scheduler state from the checkpoint, if any.
    if resume:
        print(f'loading pretrained synth model from {opt.saved_font_model}')
        checkpoint = torch.load(opt.saved_font_model,
                                map_location=lambda storage, loc: storage)

        vggFontModel.load_state_dict(checkpoint['vggFontModel'])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])

    if opt.distributed:
        vggFontModel = torch.nn.parallel.DistributedDataParallel(
            vggFontModel,
            device_ids=[opt.local_rank],
            output_device=opt.local_rank,
            broadcast_buffers=False,
            find_unused_parameters=True)
        vggFontModel.train()

    # The unwrapped module is what gets checkpointed.
    if opt.distributed:
        vggFontModel_module = vggFontModel.module
    else:
        vggFontModel_module = vggFontModel

    # Running averages; reset after every logged validation.
    loss_train = Averager()
    loss_val = Averager()
    train_acc = Averager()
    val_acc = Averager()
    train_acc_5 = Averager()
    val_acc_5 = Averager()

    # Record the final options next to the experiment outputs.
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'),
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # Start training, continuing from the checkpoint's iteration if it has one.
    start_iter = 0
    if resume:
        resumed_iter = _checkpoint_iteration(opt.saved_font_model)
        if resumed_iter is not None:
            start_iter = resumed_iter
            print(f'continue to train, start_iter: {start_iter}')

    iteration = start_iter

    while True:
        start_time = time.time()
        if not opt.testFlag:
            # One optimisation step on the next training batch.
            image_input_tensors, labels_gt = next(train_loader)
            image_input_tensors = image_input_tensors.to(device)
            labels_gt = labels_gt.view(-1).to(device)
            preds = vggFontModel(image_input_tensors)

            loss = criterion(preds, labels_gt)

            vggFontModel.zero_grad()
            loss.backward()
            optimizer.step()

            acc1, acc5 = getNumCorrect(preds,
                                       labels_gt,
                                       topk=(1, min(numClasses, 5)))
            train_acc.addScalar(acc1, preds.shape[0])
            train_acc_5.addScalar(acc5, preds.shape[0])

            loss_train.add(loss)

            if opt.lr_policy != "None":
                scheduler.step()

        if get_rank() == 0:
            # Validate every opt.valInterval iterations, at iteration 0 (to see
            # initial progress), and on every iteration in test mode.
            if (
                    iteration + 1
            ) % opt.valInterval == 0 or iteration == 0 or opt.testFlag:
                vggFontModel.eval()
                print('Inside val', iteration)

                for vCntr, (image_input_tensors,
                            labels_gt) in enumerate(valid_loader):
                    if opt.debugFlag and vCntr > 2:
                        break

                    with torch.no_grad():
                        image_input_tensors = image_input_tensors.to(device)
                        labels_gt = labels_gt.view(-1).to(device)

                        preds = vggFontModel(image_input_tensors)
                        loss = criterion(preds, labels_gt)
                        loss_val.add(loss)

                        if opt.testFlag:
                            _, preds_max = preds.max(dim=1)
                            # .item() turns the 0-d tensors into plain ints;
                            # tensors hash by identity, so they cannot be
                            # used as dictionary keys.
                            for k in range(preds_max.shape[0]):
                                print('Actual=',
                                      train_dataset.Idx2F[labels_gt[k].item()],
                                      'Predicted=',
                                      train_dataset.Idx2F[preds_max[k].item()])

                        acc1, acc5 = getNumCorrect(preds,
                                                   labels_gt,
                                                   topk=(1, min(numClasses,
                                                                5)))
                        val_acc.addScalar(acc1, preds.shape[0])
                        val_acc_5.addScalar(acc5, preds.shape[0])

                vggFontModel.train()
                elapsed_time = time.time() - start_time

                # Log, plot, then reset all running averages.
                with open(
                        os.path.join(opt.exp_dir, opt.exp_name,
                                     'log_train.txt'), 'a') as log:
                    loss_log = f'[{iteration+1}/{opt.num_iter}]  \
                        Train loss: {loss_train.val():0.5f}, Val loss: {loss_val.val():0.5f}, \
                        Train Top-1 Acc: {train_acc.val()*100:0.5f}, Train Top-5 Acc: {train_acc_5.val()*100:0.5f}, \
                        Val Top-1 Acc: {val_acc.val()*100:0.5f}, Val Top-5 Acc: {val_acc_5.val()*100:0.5f}, \
                        Elapsed_time: {elapsed_time:0.5f}'

                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Loss'),
                                  loss_train.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Val-Loss'),
                                  loss_val.val().item())
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Top-1-Acc'),
                                  train_acc.val() * 100)
                    lib.plot.plot(os.path.join(opt.plotDir, 'Train-Top-5-Acc'),
                                  train_acc_5.val() * 100)
                    lib.plot.plot(os.path.join(opt.plotDir, 'Val-Top-1-Acc'),
                                  val_acc.val() * 100)
                    lib.plot.plot(os.path.join(opt.plotDir, 'Val-Top-5-Acc'),
                                  val_acc_5.val() * 100)

                    print(loss_log)
                    log.write(loss_log + "\n")

                    loss_train.reset()
                    loss_val.reset()
                    train_acc.reset()
                    val_acc.reset()
                    train_acc_5.reset()
                    val_acc_5.reset()

                lib.plot.flush()

            # Save a checkpoint every 15000 iterations. The file is named after
            # iteration + 1, which is the value parsed back when resuming.
            if iteration % 15000 == 0:
                torch.save(
                    {
                        'vggFontModel': vggFontModel_module.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict()
                    },
                    os.path.join(opt.exp_dir, opt.exp_name, 'iter_' +
                                 str(iteration + 1) + '_vggFont.pth'))

            lib.plot.tick()

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
Esempio n. 6
0
def _eos_cut(text):
    """Return the index of the first "[s]" token in ``text``, or ``len(text)``.

    Slicing directly with ``str.find`` would turn a missing token (-1) into
    silently dropping the last character.
    """
    position = text.find('[s]')
    return len(text) if position == -1 else position


def validation(model, criterion, evaluation_loader, converter, opt):
    """Validate or evaluate a text-recognition model on ``evaluation_loader``.

    Supports CTC decoders (``'CTC' in opt.Prediction``) and attention decoders
    (all other predictions). For attention decoders, predictions and labels
    are pruned at the first "[s]" end-of-sentence token before they are
    compared.

    Returns:
        tuple: ``(mean loss as returned by Averager.val(), word accuracy in
        percent, summed normalised edit distance, predicted strings of the last
        batch, confidence scores of the last batch, labels of the last batch,
        total forward time in seconds, number of evaluated samples)``.
        The summed normalised edit distance adds ``edit_distance / len(gt)``
        per sample, or 1 for an empty ``gt``.
        For an empty loader, accuracy is 0.0 and the per-batch lists are empty.
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()
    # Defaults so that an empty loader does not end in UnboundLocalError.
    preds_str, labels, confidence_score_list = [], [], []

    for image_tensors, labels in evaluation_loader:
        batch_size = image_tensors.size(0)
        length_of_data = length_of_data + batch_size
        image = image_tensors.to(device)
        # For max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                          batch_size).to(device)
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                         1).fill_(0).to(device)

        text_for_loss, length_for_loss = converter.encode(
            labels, batch_max_length=opt.batch_max_length)

        start_time = time.time()
        if 'CTC' in opt.Prediction:
            preds = model(image, text_for_pred).log_softmax(2)
            forward_time = time.time() - start_time

            # Evaluation loss for the CTC decoder; CTCLoss expects (T, N, C).
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            cost = criterion(preds.permute(1, 0, 2), text_for_loss, preds_size,
                             length_for_loss)

            # Greedy decoding: take the max-probability index per step.
            _, preds_index = preds.max(2)
            preds_index = preds_index.view(-1)
            preds_str = converter.decode(preds_index.data, preds_size.data)

        else:
            preds = model(image, text_for_pred, is_train=False)
            forward_time = time.time() - start_time

            preds = preds[:, :text_for_loss.shape[1] - 1, :]
            target = text_for_loss[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

            # Greedy decoding: take the max-probability index per step.
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # Accuracy and confidence score. The confidence list is rebuilt per
        # batch, so only the last batch's scores are returned.
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        confidence_score_list = []
        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            if 'Attn' in opt.Prediction:
                gt = gt[:_eos_cut(gt)]
                pred_EOS = _eos_cut(pred)
                pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                pred_max_prob = pred_max_prob[:pred_EOS]

            if pred == gt:
                n_correct += 1
            if len(gt) == 0:
                norm_ED += 1
            else:
                norm_ED += edit_distance(pred, gt) / len(gt)

            # Confidence score = product of the per-step max probabilities.
            try:
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            except IndexError:
                confidence_score = 0  # empty prediction after pruning at [s]
            confidence_score_list.append(confidence_score)

    accuracy = n_correct / float(length_of_data) * 100 if length_of_data else 0.0

    return valid_loss_avg.val(
    ), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data
Esempio n. 7
0
def _eos_cut_position(text):
    """Return the index of the first "[s]" token in ``text``, or ``len(text)``.

    Slicing directly with ``str.find`` would turn a missing token (-1) into
    silently dropping the last character.
    """
    position = text.find('[s]')
    return len(text) if position == -1 else position


def _summarize_by_length(n_correct, n_norm_ED, length_of_data):
    """Convert per-length counters into percentage maps.

    Every length in ``length_of_data`` gets an entry, including lengths for
    which no prediction was correct; those report 0.0 accuracy instead of
    being omitted.

    Returns:
        tuple: ``(accuracy, norm_ED)``, two ``defaultdict(float)`` keyed by
        ground-truth length, with values in percent.
    """
    accuracy = defaultdict(float)
    norm_ED = defaultdict(float)
    for k, count in length_of_data.items():
        accuracy[k] = n_correct.get(k, 0.0) / float(count) * 100
        norm_ED[k] = n_norm_ED.get(k, 0.0) / float(count) * 100
    return accuracy, norm_ED


def validation_by_length(model, criterion, evaluation_loader, converter, opt):
    """Evaluate a text-recognition model, reporting metrics per label length.

    Samples are grouped by the length of their (possibly filtered) ground-truth
    string. For attention decoders, predictions and labels are pruned at the
    first "[s]" token. With ``opt.sensitive and opt.data_filtering_off``, both
    strings are lower-cased and stripped to ``[0-9a-z]`` before comparison.

    Returns:
        tuple: ``(mean loss, {length: word accuracy %}, {length: mean
        ICDAR2019 score %}, predicted strings of the last batch, confidence
        scores of the last batch, labels of the last batch, total forward time
        in seconds, {length: sample count})``.
        The ICDAR2019 score is ``1 - normalised edit distance``, i.e. a
        similarity.
    """
    n_correct = defaultdict(float)
    n_norm_ED = defaultdict(float)
    length_of_data = defaultdict(int)
    infer_time = 0
    valid_loss_avg = Averager()
    # Defaults so that an empty loader does not end in UnboundLocalError.
    preds_str, labels, confidence_score_list = [], [], []

    # Character filter for the case-insensitive alphanumeric evaluation
    # (loop-invariant, so built once).
    alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz'
    out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]'

    for image_tensors, labels in evaluation_loader:
        batch_size = image_tensors.size(0)
        image = image_tensors.to(device)
        # For max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                          batch_size).to(device)
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                         1).fill_(0).to(device)

        text_for_loss, length_for_loss = converter.encode(
            labels, batch_max_length=opt.batch_max_length)

        start_time = time.time()
        if 'CTC' in opt.Prediction:
            preds = model(image, text_for_pred)
            forward_time = time.time() - start_time

            # Evaluation loss for the CTC decoder; CTCLoss expects (T, N, C).
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            cost = criterion(
                preds.log_softmax(2).permute(1, 0, 2), text_for_loss,
                preds_size, length_for_loss)

            # Greedy decoding: take the max-probability index per step.
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index.data, preds_size.data)

        else:
            preds, alphas = model(image, text_for_pred, is_train=False)
            forward_time = time.time() - start_time

            preds = preds[:, :text_for_loss.shape[1] - 1, :]
            target = text_for_loss[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

            # Greedy decoding: take the max-probability index per step.
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # Accuracy and confidence score. The confidence list is rebuilt per
        # batch, so only the last batch's scores are returned.
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        confidence_score_list = []
        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            if 'Attn' in opt.Prediction:
                gt = gt[:_eos_cut_position(gt)]
                pred_EOS = _eos_cut_position(pred)
                pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                pred_max_prob = pred_max_prob[:pred_EOS]

            # Evaluate a case-sensitive model in the alphanumeric,
            # case-insensitive setting.
            if opt.sensitive and opt.data_filtering_off:
                pred = pred.lower()
                gt = gt.lower()
                pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred)
                gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt)

            length_of_data[len(gt)] += 1

            if pred == gt:
                n_correct[len(gt)] += 1

            # ICDAR2019 normalised edit distance, accumulated as a similarity
            # (1 - ED / max length). Empty strings contribute 0 but still
            # register their length bucket.
            if len(gt) == 0 or len(pred) == 0:
                n_norm_ED[len(gt)] += 0
            elif len(gt) > len(pred):
                n_norm_ED[len(gt)] += 1 - edit_distance(pred, gt) / len(gt)
            else:
                n_norm_ED[len(gt)] += 1 - edit_distance(pred, gt) / len(pred)

            # Confidence score = product of the per-step max probabilities.
            try:
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            except IndexError:
                confidence_score = 0  # empty prediction after pruning at [s]
            confidence_score_list.append(confidence_score)

    accuracy, norm_ED = _summarize_by_length(n_correct, n_norm_ED,
                                             length_of_data)

    return valid_loss_avg.val(
    ), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data
Esempio n. 8
0
    def validate(self, epoch, val_loader, save=False):
        """Validate the detector on every CT in ``val_loader`` and log metrics.

        Each CT is split into crops, and ``Batch_Filler`` packs the crops into
        batches of ``batch_size``. Predictions are decoded per crop, stitched
        back into whole-CT coordinates with ``split_comb.combine``, and scored
        with every function in ``self.emlist``. After that, FROC and a
        precision-at-recall table (recall 10%..100%) are computed. Both use
        predictions whose size is >= ``self.small_size``.

        When not in pure-validation mode (``self.args.val`` false), this also
        saves a checkpoint every ``save_frequency`` epochs. It returns early
        unless ``epoch`` is a multiple of ``val_frequency``.

        :param epoch: current epoch number.
        :param val_loader: DataLoader over the CTs to test.
        :param save: if True, write per-CT predictions to
            ``<save_dir>/<epoch:03d>/<name>.npy``.
        """
        startt = time.time()
        self.net.eval()
        if not self.args.val:
            if epoch % self.args.output['save_frequency'] == 0:
                self.ioer.save_file(self.net, epoch, self.args, 0)
            if epoch % self.args.output['val_frequency'] != 0:
                return

        em_avg = Averager()
        bs = self.args.train['batch_size']
        if save:
            savedir = os.path.join(self.ioer.save_dir, '%03d' % epoch)
            if not os.path.exists(savedir):
                os.mkdir(savedir)

        ap_pred_list = []
        ap_gt_list = []
        # Defined up front so the final log call works even if no batch ran.
        em_list = []
        with torch.no_grad():

            xbatch_filler = Batch_Filler(
                bs,
                size=[self.args.prepare['channel_input']] + self.args.prepare['crop_size'],
                dtype=self.dtype)
            bbox_left = None
            infos_list = []

            for sample_idx, tmp in enumerate(val_loader):
                # zhw: shape of data
                data, zhw, name, fullab = tmp[0]

                infos_list.append([zhw, name, fullab])
                x_left = torch.from_numpy(data).cuda()

                # Keep pulling bs-sized batches of crops out of this CT.
                while x_left is not None:
                    isFull, xbatch, belong, idxlist, x_left = xbatch_filler.fill(sample_idx, x_left)

                    if len(val_loader) == sample_idx + 1:  # the last sample force execute test operation.
                        isFull = True
                        if len(idxlist) != bs:
                            fill_length = bs - len(idxlist)
                            for i in range(fill_length):
                                idxlist.append(-1)

                    if isFull:
                        data = xbatch
                        logits = self.warp(data, calc_loss=False)
                        logits = self.clipmargin(list(logits))

                        box_batch = MaskableList()
                        thresh_lists = []
                        for i_batch in range(data.shape[0]):
                            # Box coordinates & confidences for this crop.
                            box_iter, thresh_list = decode_bbox(logits, thresh=-2, idx=i_batch, config=self.args)
                            box_batch.append(box_iter)
                            thresh_lists.append(thresh_list)

                        for i_idx, idx in enumerate(idxlist):
                            if idx == -1:
                                break
                            zhw, name, fullab = infos_list[i_idx]
                            fullab = torch.from_numpy(fullab).cuda()
                            zhw = torch.from_numpy(zhw)
                            bbox_pieces = box_batch[belong==idx]

                            if bbox_left is not None:
                                bbox_pieces = bbox_left + bbox_pieces
                                bbox_left = None

                            # Map crop coordinates back onto the full image.
                            # comb_pred: all predicted boxes, [n, 8] =
                            # [z1,y1,x1,z2,y2,x2,confidence, cls]
                            comb_pred = val_loader.dataset.split_comb.combine(bbox_pieces, zhw)

                            # Tally the metric results for this CT.
                            em_list = []
                            if self.emlist is not None:
                                for em_fun in self.emlist:
                                    # Hit status of every predicted box.
                                    # fullab: all ground-truth boxes, [n, 10] =
                                    # [z,y,x,dz,dy,dx,cls,1,1,1]
                                    # (original note: iou == 0.2)
                                    em_result, iou_info = em_fun(comb_pred, fullab)

                                    if len(comb_pred) > 0:
                                        comb_pred_tmp = comb_pred.cpu().numpy()
                                        # Prediction confidence.
                                        pred_probs = comb_pred_tmp[:, 6:7]
                                        # Box size is taken as its z-extent.
                                        bbox_size = comb_pred_tmp[:, 3:4] - comb_pred_tmp[:, 0:1]
                                        # [prob, diameter, hit-iou, [coords]]
                                        ap_pred_list.append(np.concatenate([pred_probs, bbox_size, iou_info[:, :1], comb_pred_tmp[:, :6]], axis=1))
                                    else:
                                        ap_pred_list.append([])
                                    if fullab.shape[0] > 0:
                                        fullab_tmp = fullab.cpu().numpy()
                                        lab_size = fullab_tmp[:, 3:4]
                                        lab_center = fullab_tmp[:, :3]
                                        ap_gt_list.append(np.concatenate([lab_size, lab_center], axis=1))
                                    else:
                                        ap_gt_list.append([])
                                    em_list.append(em_result)
                                em_avg.update(tuple(em_list))

                            info = 'end %d out of %d, name %s, '%(idx, len(val_loader), name)
                            for lid, l in enumerate(em_list):
                                if isinstance(l,dict):
                                    for k,v in l.items():
                                        info += '%s: %.2f, '%(k, v)
                                else:
                                    info += '%d: %.2f, '%(lid, l)
                            threshs = np.array(thresh_lists).mean(axis=0)
                            info += 'thresh: '
                            for level, thresh in enumerate(threshs):
                                info += 'level %d= %.02f, '%(level, thresh)
                            print(info)
                            if save:
                                if isinstance(comb_pred, torch.Tensor):
                                    comb_pred = comb_pred.cpu().numpy()
                                try:
                                    np.save(os.path.join(savedir, name+'.npy'), np.concatenate([comb_pred, iou_info], axis=1))
                                except Exception:
                                    # Best effort: report which CT could not be saved.
                                    print(name)
                        bbox_left_new = box_batch[(belong)==belong[-1]]
                        bbox_left, infos_list = restart_logit(x_left, bbox_left, bbox_left_new, infos_list)
                        xbatch_filler.restart()

        # Close the pickle files deterministically.
        with open(os.path.join(self.ioer.save_dir, 'ap_pred_list.pkl'), 'wb') as f:
            cPickle.dump(ap_pred_list, f)
        with open(os.path.join(self.ioer.save_dir, 'ap_gt_list.pkl'), 'wb') as f:
            cPickle.dump(ap_gt_list, f)
        ap_small_bbox_list = []
        ap_big_bbox_list = []
        ap_small_gt_bbox_count = 0
        ap_big_gt_bbox_count = 0

        # Both lists get one entry per CT per metric function.
        assert len(ap_pred_list) == len(ap_gt_list)
        for idx, ap_pred in enumerate(ap_pred_list):
            ap_gt = ap_gt_list[idx]

            for ap_pred_x in ap_pred:
                # A hit (iou_info >= 0) uses the matched ground-truth size.
                if ap_pred_x[2] >= 0:
                    bbox_size = ap_gt[int(ap_pred_x[2])][0]
                else:
                    bbox_size = ap_pred_x[1]

                # Entry: [prob, "<ct idx>_<gt idx>", size]
                if bbox_size < self.small_size:
                    ap_small_bbox_list.append([ap_pred_x[0], str(idx) + '_' + str(int(ap_pred_x[2])), bbox_size])
                else:
                    ap_big_bbox_list.append([ap_pred_x[0], str(idx) + '_' + str(int(ap_pred_x[2])), bbox_size])

            for ap_gt_x in ap_gt:
                bbox_size = ap_gt_x[0]
                if bbox_size < self.small_size:
                    ap_small_gt_bbox_count += 1
                else:
                    ap_big_gt_bbox_count += 1

        # FROC over the big boxes.
        froc_val = self.froc(bbox_info=ap_big_bbox_list,
                             gt_count=ap_big_gt_bbox_count,
                             fps=[0.5, 1, 2, 4, 8],
                             n_ct=len(val_loader))
        self.printf('FROC: ' + str(froc_val))

        rp_list = []
        # Only the big boxes are evaluated here.
        ap_bbox_list = ap_big_bbox_list
        ap_gt_bbox_count = ap_big_gt_bbox_count
        recall_level = 1
        rp = {}

        # Sort by probability, highest first.
        # ap_bbox_list: [prob, id_cls, size]
        ap_bbox_list.sort(key=lambda x: -x[0])
        gt_bbox_hits = set()  # each ground-truth box counts once
        pred_bbox_hit_count = 0
        for idx, ap_bbox in enumerate(ap_bbox_list):
            bbox_tag = ap_bbox[1]
            if not bbox_tag.endswith('-1'):
                pred_bbox_hit_count += 1
                gt_bbox_hits.add(bbox_tag)
            # Record [precision, prob] the first time recall reaches each 10%
            # step; skipped when there are no ground-truth boxes (avoids /0).
            while ap_gt_bbox_count > 0 and len(gt_bbox_hits) / ap_gt_bbox_count >= recall_level*0.1 and recall_level <= 10:
                rp[recall_level] = [pred_bbox_hit_count / (idx + 1), ap_bbox[0]]
                recall_level += 1
        rp_list.append(rp)

        if self.emlist is not None:
            em_list = em_avg.val()
        endt = time.time()
        self.writeLossLog('Val', epoch, meanloss = 0, loss_list = [], em_list=em_list, time=(endt-startt)/60)
        self.printf('big: ' + str(rp_list))
Esempio n. 9
0
def validation(model, criterion, evaluation_loader, converter, opt):
    """Validate or evaluate ``model`` over ``evaluation_loader``.

    Every parameter is frozen (``requires_grad = False``) while batches are
    processed. Each parameter's previous flag is restored before returning, so
    calling this from a training loop does not leave the model frozen.

    Args:
        model: recognition network. Its call signature depends on the decoder
            named in ``opt.Prediction``: 'CTC', 'Bert', 'SRN', or anything else,
            which takes the attention-style path.
        criterion: loss matching the decoder. For 'SRN' it returns a pair whose
            first element is the loss.
        evaluation_loader: iterable of ``(image_tensors, labels)`` batches.
        converter: label converter exposing ``encode`` and ``decode``.
        opt: options namespace (``Prediction``, ``batch_max_length``,
            ``SRN_PAD``).

    Returns:
        ``(mean_loss, accuracy, norm_ED, preds_str, labels, infer_time,
        length_of_data)``, where:

        - ``accuracy`` is a percentage (0.0 for an empty loader).
        - ``norm_ED`` is the *sum* over samples of
          ``edit_distance(pred, gt) / len(gt)``; an empty ground truth adds 1.
        - ``preds_str`` and ``labels`` come from the last batch.
        - ``infer_time`` is the total forward time in seconds.

    Note:
        Inputs are moved with ``.cuda()`` and CUDA tensor constructors are
        used, so a CUDA device is required.
    """
    print('start validation')
    params = list(model.parameters())
    # Remember each flag so the freeze below can be undone on exit.
    prev_requires_grad = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad = False

    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()
    # Defined up front so an empty loader still yields a well-formed result.
    preds_str, labels = [], []

    try:
        for image_tensors, labels in evaluation_loader:
            batch_size = image_tensors.size(0)
            length_of_data = length_of_data + batch_size
            image = image_tensors.cuda()
            # For max length prediction
            length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size)
            text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0)

            text_for_loss, length_for_loss = converter.encode(labels)

            start_time = time.time()
            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred).log_softmax(2)
                forward_time = time.time() - start_time

                # Calculate evaluation loss for CTC decoder.
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                preds = preds.permute(1, 0, 2)  # to use CTCloss format
                cost = criterion(preds, text_for_loss, preds_size, length_for_loss)

                # Select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)

            elif 'Bert' in opt.Prediction:
                with torch.no_grad():
                    pad_mask = None
                    preds = model(image, pad_mask)
                    forward_time = time.time() - start_time

                    # Both model outputs are supervised with the same target.
                    flat_target = text_for_loss.contiguous().view(-1)
                    cost = criterion(preds[0].view(-1, preds[0].shape[-1]), flat_target) + \
                           criterion(preds[1].view(-1, preds[1].shape[-1]), flat_target)

                    # Greedy decoding on the second output.
                    _, preds_index = preds[1].max(2)
                    length_for_pred = torch.cuda.IntTensor([preds_index.size(-1)] * batch_size)
                    preds_str = converter.decode(preds_index, length_for_pred)
                    labels = converter.decode(text_for_loss, length_for_loss)

            elif 'SRN' in opt.Prediction:
                with torch.no_grad():
                    preds = model(image, None)
                    forward_time = time.time() - start_time

                    cost, _ = criterion(preds, text_for_loss, opt.SRN_PAD)

                    # Greedy decoding on the third output.
                    _, preds_index = preds[2].max(2)
                    preds_str = converter.decode(preds_index, length_for_pred)
                    labels = converter.decode(text_for_loss, length_for_loss)

            else:
                preds = model(image, text_for_pred, is_train=False)
                forward_time = time.time() - start_time

                preds = preds[:, :text_for_loss.shape[1] - 1, :]
                target = text_for_loss[:, 1:]  # without [GO] Symbol
                cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))

                # Greedy decoding, then index -> character.
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)
                labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

            infer_time += forward_time
            valid_loss_avg.add(cost)

            # calculate accuracy.
            for pred, gt in zip(preds_str, labels):
                if 'Attn' in opt.Prediction:
                    # Prune after the end-of-sentence token ([s]). str.find
                    # returns -1 when it is absent, and slicing with -1 would
                    # silently drop the last character, so only cut when found.
                    pred_eos = pred.find('[s]')
                    if pred_eos != -1:
                        pred = pred[:pred_eos]
                    gt_eos = gt.find('[s]')
                    if gt_eos != -1:
                        gt = gt[:gt_eos]

                if pred == gt:
                    n_correct += 1

                if len(gt) == 0:
                    norm_ED += 1
                else:
                    norm_ED += edit_distance(pred, gt) / len(gt)
    finally:
        # Undo the freeze even if a batch raised.
        for p, flag in zip(params, prev_requires_grad):
            p.requires_grad = flag

    accuracy = n_correct / float(length_of_data) * 100 if length_of_data else 0.0

    return valid_loss_avg.val(), accuracy, norm_ED, preds_str, labels, infer_time, length_of_data
Esempio n. 10
0
def train(opt):
    """Train a text-recognition model as configured by ``opt``.

    Setup:
        - Builds the balanced training dataset and a shuffled validation loader.
        - Creates ``Model(opt)`` with a CTC or attention label converter.
        - Optionally loads ``opt.saved_model``.

    Loop (one optimisation step per iteration ``i``):
        - Every ``opt.valInterval`` iterations: validate, log to TensorBoard
          and ``./saved_models/<experiment_name>/log_train.txt``, and save the
          best-accuracy / best-norm-ED checkpoints.
        - Every 100000 iterations: save ``iter_<N>.pth``.
        - When ``i == opt.num_iter``: print a message and call ``sys.exit()``.
          The TensorBoard writer is closed on the way out.

    Side effects:
        - Turns ``opt.select_data`` / ``opt.batch_ratio`` into lists.
        - Sets ``opt.num_class``, and ``opt.input_channel`` when ``opt.rgb``.
        - Appends to ``opt.txt`` and writes log and checkpoint files.
    """
    # TensorBoard writer; closed in the ``finally`` around the training loop.
    writer = SummaryWriter()

    # --- dataset preparation ---
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    print('-' * 80)

    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # kaiming init rejects <2-D weights (e.g. batchnorm)
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            # Fine-tuning: tolerate missing / unexpected keys.
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # Only parameters that require gradients are handed to the optimizer.
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # --- final options: dump every option to opt.txt and stdout ---
    with open(f'./saved_models/{opt.experiment_name}/opt.txt',
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0
    if opt.saved_model != '':
        # Resume the counter from names like 'iter_300000.pth'. Names without a
        # number (e.g. 'best_accuracy.pth') cannot be parsed; start from 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    # NOTE(review): "lower is better" assumes validation() returns an
    # edit-distance cost. If it returns the ICDAR2019 normalised score
    # (higher is better), start from -1 and compare with '>' instead.
    best_norm_ED = 1e+6
    i = start_iter

    try:
        while True:
            # train part
            image_tensors, labels = train_dataset.get_batch()
            image = image_tensors.to(device)
            text, length = converter.encode(labels,
                                            batch_max_length=opt.batch_max_length)
            batch_size = image.size(0)

            if 'CTC' in opt.Prediction:
                preds = model(image, text).log_softmax(2)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                preds = preds.permute(1, 0, 2)  # (T, N, C) layout for CTCLoss

                # (ctc_a) For PyTorch 1.2.0 and 1.3.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
                # https://github.com/jpuigcerver/PyLaia/issues/16
                torch.backends.cudnn.enabled = False
                cost = criterion(preds, text.to(device), preds_size.to(device),
                                 length.to(device))
                torch.backends.cudnn.enabled = True
            else:
                preds = model(image, text[:, :-1])  # align with Attention.forward
                target = text[:, 1:]  # without [GO] Symbol
                cost = criterion(preds.view(-1, preds.shape[-1]),
                                 target.contiguous().view(-1))

            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer.step()

            loss_avg.add(cost)
            writer.add_scalar('train_loss', loss_avg.val(), i)

            # validation part
            if i % opt.valInterval == 0:
                elapsed_time = time.time() - start_time
                # for log
                with open(f'./saved_models/{opt.experiment_name}/log_train.txt',
                          'a') as log:
                    model.eval()
                    with torch.no_grad():
                        (valid_loss, current_accuracy, current_norm_ED, preds,
                         confidence_score, labels, infer_time,
                         length_of_data) = validation(
                             model, criterion, valid_loader, converter, opt)
                    model.train()

                    # training loss and validation loss
                    writer.add_scalar('valid_loss', valid_loss, i)
                    writer.add_scalar('elapsed_time', elapsed_time, i)
                    loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                    print(loss_log)
                    log.write(loss_log + '\n')
                    loss_avg.reset()

                    current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'
                    print(current_model_log)
                    log.write(current_model_log + '\n')

                    # keep best accuracy model (on valid dataset)
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                        )

                    if current_norm_ED < best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                        )
                    best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
                    print(best_model_log)
                    log.write(best_model_log + '\n')

                    # show some predicted results
                    print('-' * 80)
                    print(
                        f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                    )
                    log.write(
                        f'{"Ground Truth":25s} | {"Prediction":25s} | {"Confidence Score"}\n'
                    )
                    print('-' * 80)
                    for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                    confidence_score[:5]):
                        if 'Attn' in opt.Prediction:
                            # Cut at [s] only when present: str.find returns -1
                            # otherwise, and slicing with -1 drops a character.
                            gt_eos = gt.find('[s]')
                            if gt_eos != -1:
                                gt = gt[:gt_eos]
                            pred_eos = pred.find('[s]')
                            if pred_eos != -1:
                                pred = pred[:pred_eos]

                        print(
                            f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}'
                        )
                        log.write(
                            f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                        )
                    print('-' * 80)

            # save model per 1e+5 iter.
            if (i + 1) % 100000 == 0:
                torch.save(
                    model.state_dict(),
                    f'./saved_models/{opt.experiment_name}/iter_{i + 1}.pth')

            if i == opt.num_iter:
                print('end the training')
                sys.exit()
            i += 1
    finally:
        # sys.exit() raises SystemExit, so this also runs at normal termination
        # and flushes the pending TensorBoard events.
        writer.close()
Esempio n. 11
0
    def train(self, epoch, dataloader):
        """Train ``self.net`` for one epoch and return the averaged loss terms.

        For each batch this method:
            1. Moves inputs to the GPU when one is available.
            2. Computes ``(losses, weights, pred_prob_list)`` with ``self.warp``.
            3. Feeds ``pred_prob_list`` back into the dataset's per-nodule
               sampling weights.
            4. Builds ``total_loss`` as the sum of ``loss / (1e-3 + weight)``
               terms.
            5. Back-propagates it (via ``self.optimizer.backward`` when
               ``self.half``) and steps ``self.optimizer``.

        A progress line is printed per batch, and ``self.writeLossLog('Train',
        ...)`` is called at the end.

        Args:
            epoch: epoch index, passed to ``self.getLR`` and the loss log.
            dataloader: yields ``(data, fpn_prob, fpn_coord_prob,
                fpn_coord_diff, fpn_diff, fpn_connects, names)``. Its
                ``dataset`` must provide ``cases2idx``, ``cases``,
                ``sample_weights`` and ``lab_buffers``; the last two are
                mutated here.

        Returns:
            ``lls_avg.val()``: the ``Averager`` that was updated with every
            batch's tuple of per-term losses.
        """
        use_cuda = torch.cuda.is_available()
        self.net.train()

        # When self.args.train['freeze'] is set, switch BatchNorm / ABN layers
        # back to eval mode while the rest of the network trains. For
        # _BatchNorm, eval mode means the running statistics are used and
        # no longer updated.
        for m in self.net.modules():
            if isinstance(m, _BatchNorm) or isinstance(m, ABN):
            # if isinstance(m, _BatchNorm) or isinstance(m, InPlaceABNSync) or isinstance(m,InPlaceABN):
                if self.args.train['freeze']:
                    m.eval()

        # Same learning rate for every parameter group this epoch.
        lr = self.getLR(epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        loss_avg = Averager()  # fed with the total weighted loss of each batch
        lls_avg = Averager()  # fed with the tuple of per-term losses of each batch

        startt = time.time()
        lastt0 = startt
        for batch_idx, (data, fpn_prob, fpn_coord_prob, fpn_coord_diff, fpn_diff, fpn_connects, names) in enumerate(dataloader):
            # Time since the previous batch was received: the previous step's
            # compute plus this batch's loading.
            t0 = time.time()
            iter_time = t0-lastt0
            lastt0 = t0
            # Dataset index of every case name in the batch.
            case_idxs = [dataloader.dataset.cases2idx[item] for item in names]
            if use_cuda:
                data = data.cuda()
                fpn_prob = [f.cuda() for f in fpn_prob]
                fpn_connects = [f.cuda() for f in fpn_connects]
                fpn_coord_prob = [f.cuda() for f in fpn_coord_prob]
                fpn_coord_diff = [f.cuda() for f in fpn_coord_diff]
                fpn_diff =  [f.cuda() for f in fpn_diff]
                # NOTE(review): only this branch turns case_idxs into a Tensor.
                # On the CPU path self.warp receives a plain list; confirm it
                # accepts both.
                case_idxs = torch.Tensor(case_idxs).cuda()

            losses, weights, pred_prob_list = self.warp(data, fpn_prob, fpn_coord_prob, fpn_coord_diff, fpn_diff, fpn_connects, case_idxs)
            # print(losses,'losses')
            # print(weights, 'weights')
            if pred_prob_list is not None:
                # Each row is unpacked as (case_idx, nodule_idx, n_weight).
                # Rows whose first field is -1 are skipped (presumably padding).
                # Rows are routed by the sign of nodule_idx:
                #   - positive: keep the smallest weight seen in this batch
                #   - negative: keep the largest weight seen in this batch
                pred_prob_dict_pos = {}
                pred_prob_dict_neg = {}
                for pred_prob in pred_prob_list.cpu().numpy():
                    if pred_prob[0] == -1:
                        continue
                    case_idx, nodule_idx, n_weight = pred_prob
                    assert nodule_idx != 0
                    nodule_key = dataloader.dataset.cases[int(case_idx)] + '___' + str(abs(int(nodule_idx)))
                    if nodule_idx > 0:
                        if nodule_key not in pred_prob_dict_pos:
                            pred_prob_dict_pos[nodule_key] = n_weight
                        else:
                            pred_prob_dict_pos[nodule_key] = min(n_weight, pred_prob_dict_pos[nodule_key])
                    elif nodule_idx < 0:
                        if nodule_key not in pred_prob_dict_neg:
                            pred_prob_dict_neg[nodule_key] = n_weight
                        else:
                            pred_prob_dict_neg[nodule_key] = max(n_weight, pred_prob_dict_neg[nodule_key])
                # Write the positive weights back into sample_weights[key][0].
                # If a weight exceeds self.pos_weight_thresh, counter [2] is
                # incremented. Once the counter reaches 3, field 5 of the
                # nodule's lab_buffers row is set to 0. Nodule numbers are
                # 1-based, hence the -1. Presumably this removes the nodule
                # from further sampling; confirm against the dataset class.
                for nodule_key, n_weight in pred_prob_dict_pos.items():
                    assert nodule_key in dataloader.dataset.sample_weights
                    dataloader.dataset.sample_weights[nodule_key][0] = n_weight
                    if n_weight > self.pos_weight_thresh:
                        dataloader.dataset.sample_weights[nodule_key][2] += 1
                        if dataloader.dataset.sample_weights[nodule_key][2] >= 3:
                            case_name, nodule_idx = nodule_key.split('___')
                            dataloader.dataset.lab_buffers[case_name][int(nodule_idx)-1][5] = 0

                # pred_prob_dict_neg is built above, but its only consumer is
                # commented out below, so negative weights are currently unused.
                #for nodule_key, n_weight in pred_prob_dict_neg.items():
                #    assert nodule_key in dataloader.dataset.neg_sample_weights
                #    dataloader.dataset.neg_sample_weights[nodule_key][0] = n_weight

            # Collapse the leading dimension of losses/weights (presumably one
            # entry per GPU replica; confirm). When weights has exactly twice as
            # many terms as losses:
            #   - the first half normalises the loss that is optimised;
            #   - the second half ("fack"/fake weights) normalises only the
            #     values appended to loss_list for logging.
            losses = losses.sum(dim=0)
            weights = weights.sum(dim=0)
            if weights.shape[0] > losses.shape[0]:
                assert weights.shape[0] == losses.shape[0] * 2
                fack_weights = weights[losses.shape[0]:]
                weights = weights[:losses.shape[0]]
            else:
                fack_weights = None
            total_loss = 0
            loss_list = []
            # The 1e-3 offset keeps each division finite when a weight sum is 0.
            if fack_weights is not None:
                for l, w, fw in zip(losses, weights, fack_weights):
                    l_tmp = (l/ (1e-3+w))
                    total_loss += l_tmp
                    fack_l_tmp = (l/ (1e-3+fw))
                    loss_list.append(fack_l_tmp.detach().cpu().numpy())
            else:
                for l, w in zip(losses, weights):
                    l_tmp = (l/ (1e-3+w))
                    total_loss += l_tmp
                    loss_list.append(l_tmp.detach().cpu().numpy())

            loss_avg.update(total_loss.detach().cpu().numpy())
            info = 'end %d out of %d, '%(batch_idx, len(dataloader))

            

            for lid, l in enumerate(loss_list):
                info += 'loss %d: %.4f, '%(lid, np.mean(l))
            info += 'time: %.2f' %iter_time
            print(info)
            lls_avg.update(tuple(loss_list))
            self.optimizer.zero_grad()
            loss_scalar = total_loss
            if self.half:
                # Mixed-precision path: the optimizer wrapper performs the
                # backward pass itself (presumably an apex-style FP16 optimizer
                # that applies loss scaling; confirm).
                self.optimizer.backward(loss_scalar)
                #self.optimizer.clip_master_grads(1)
            else:
                loss_scalar.backward()
            if self.args.clip:
                # Clamp every gradient element of self.warp's parameters to [-1, 1].
                torch.nn.utils.clip_grad_value_(self.warp.parameters(),1)
            self.optimizer.step()
        endt = time.time()
        # Epoch duration is logged in minutes.
        self.writeLossLog('Train', epoch, meanloss = loss_avg.val(), loss_list = lls_avg.val(), lr = lr, time=(endt-startt)/60)

        return lls_avg.val()
Esempio n. 12
0
def validation(model, criterion, evaluation_loader, converter, opt):
    """Validate or evaluate a CTC-decoded model over ``evaluation_loader``.

    Args:
        model: network exposing ``model.module.inference(image, text)``
            (presumably a ``DataParallel`` wrapper; confirm).
        criterion: CTC loss. With ``opt.baiduCTC`` it receives raw logits;
            otherwise it receives log-softmax outputs.
        evaluation_loader: iterable of ``(image_tensors, labels)`` batches.
        converter: CTC label converter exposing ``encode`` and ``decode``.
        opt: options namespace (``batch_max_length``, ``baiduCTC``,
            ``label_smooth``, ``sensitive``, ``data_filtering_off``).

    Returns:
        ``(mean_loss, accuracy, norm_ED, preds_str, confidence_score_list,
        labels, infer_time, length_of_data)``, where:

        - ``accuracy`` is a percentage.
        - ``norm_ED`` is the mean ICDAR2019 normalised edit-distance score,
          ``1 - ED / max(len(gt), len(pred))``; a pair with an empty string
          scores 0.
        - ``preds_str``, the confidences and ``labels`` come from the last
          batch only.
        - For an empty loader, ``accuracy`` is 0.0 and ``norm_ED`` is 0.
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()
    # Defined up front so an empty loader still produces a result.
    preds_str, confidence_score_list, labels = [], [], []

    for image_tensors, labels in evaluation_loader:
        batch_size = image_tensors.size(0)
        length_of_data += batch_size
        image = image_tensors.to(device)
        # Zero-filled decoder input covering the maximum label length.
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                         1).fill_(0).to(device)

        text_for_loss, length_for_loss = converter.encode(
            labels, batch_max_length=opt.batch_max_length)

        start_time = time.time()
        preds = model.module.inference(image, text_for_pred)
        forward_time = time.time() - start_time

        # Evaluation loss for the CTC decoder; permute(1, 0, 2) gives the
        # (T, N, C) layout the CTC losses expect.
        preds_size = torch.IntTensor([preds.size(1)] * batch_size)
        if opt.baiduCTC:
            if opt.label_smooth:
                cost = criterion(preds.permute(1, 0, 2), text_for_loss,
                                 preds_size, length_for_loss, batch_size)
            else:
                cost = criterion(preds.permute(1, 0, 2), text_for_loss,
                                 preds_size, length_for_loss) / batch_size
        else:
            cost = criterion(
                preds.log_softmax(2).permute(1, 0, 2), text_for_loss,
                preds_size, length_for_loss)

        # Greedy decoding: most probable class at every time step.
        _, preds_index = preds.max(2)
        if opt.baiduCTC:
            preds_index = preds_index.view(-1)
        preds_str = converter.decode(preds_index.data, preds_size.data)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # calculate accuracy & confidence score
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        confidence_score_list = []
        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            # To evaluate 'case sensitive model' with alphanumeric and case insensitive setting.
            if opt.sensitive and opt.data_filtering_off:
                pred = pred.lower()
                gt = gt.lower()
                alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz'
                out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]'
                pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred)
                gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt)

            if pred == gt:
                n_correct += 1

            # ICDAR2019 Normalized Edit Distance. A pair with an empty string
            # contributes 0; otherwise normalise by the longer string.
            if gt and pred:
                norm_ED += 1 - edit_distance(pred, gt) / max(len(gt), len(pred))

            # Confidence score = product of the per-step max probabilities.
            try:
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            except IndexError:
                confidence_score = 0  # empty prediction sequence
            confidence_score_list.append(confidence_score)

    if length_of_data:
        accuracy = n_correct / float(length_of_data) * 100
        norm_ED = norm_ED / float(length_of_data)  # ICDAR2019 Normalized Edit Distance
    else:
        accuracy = 0.0

    return (valid_loss_avg.val(), accuracy, norm_ED, preds_str,
            confidence_score_list, labels, infer_time, length_of_data)
Esempio n. 13
0
def validation_ctc_and_attn(model, criterion_ctc, criterion_attn,
                            evaluation_loader, converter_ctc, converter_attn,
                            opt):
    """Evaluate a model with parallel CTC and attention heads.

    Args:
        model: network returning ``(preds_ctc, preds_attn)`` when called with
            ``is_train=False``. It also exposes ``model.module.inference``, a
            call that is timed for ``infer_time``.
        criterion_ctc: CTC loss. With ``opt.baiduCTC`` it receives raw logits
            and the result is divided by the batch size.
        criterion_attn: cross-entropy-style loss for the attention head.
        evaluation_loader: iterable of ``(image_tensors, labels)`` batches.
        converter_ctc: CTC label converter (``encode`` / ``decode``).
        converter_attn: attention label converter (``encode`` / ``decode``).
        opt: options namespace (``batch_max_length``, ``baiduCTC``, ...).

    Returns:
        A 13-tuple ``(ctc_loss, ctc_accuracy, ctc_norm_ED, ctc_preds_str,
        ctc_confidence_list, labels, infer_time, length_of_data, attn_loss,
        attn_accuracy, attn_norm_ED, attn_preds_str, attn_confidence_list)``.

        - Accuracies are percentages.
        - The norm-ED values are summed over all batches and divided by
          ``length_of_data``.
        - String lists, confidence lists and ``labels`` come from the last
          batch.
        - For an empty loader, all accuracies and norm-ED values are 0.0.
    """
    n_correct_all = 0
    n_correct_all_attn = 0
    norm_ED_all = 0
    norm_ED_all_attn = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg_ctc = Averager()
    valid_loss_avg_attn = Averager()
    # Defined up front so an empty loader still produces a result.
    preds_str, confidence_score_list, labels = [], [], []
    preds_str_attn, confidence_score_list_attn = [], []

    for image_tensors, labels in evaluation_loader:
        batch_size = image_tensors.size(0)
        length_of_data += batch_size
        image = image_tensors.to(device)
        # For max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                          batch_size).to(device)
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                         1).fill_(0).to(device)

        text_for_loss_ctc, length_for_loss_ctc = converter_ctc.encode(
            labels, batch_max_length=opt.batch_max_length)
        text_for_loss_attn, length_for_loss_attn = converter_attn.encode(
            labels, batch_max_length=opt.batch_max_length)

        # NOTE(review): only this inference() pass is timed, and its output is
        # unused. The predictions below come from a second forward pass.
        start_time = time.time()
        model.module.inference(image, text_for_pred)
        forward_time = time.time() - start_time
        preds_ctc, preds_attn = model(image,
                                      text_for_loss_attn,
                                      is_train=False)

        # Attention head: align the predictions with the target minus [GO].
        preds_attn = preds_attn[:, :text_for_loss_attn.shape[1] - 1, :]
        target = text_for_loss_attn[:, 1:]
        cost_attn = criterion_attn(
            preds_attn.contiguous().view(-1, preds_attn.shape[-1]),
            target.contiguous().view(-1))
        _, preds_index_attn = preds_attn.max(2)
        preds_str_attn = converter_attn.decode(preds_index_attn,
                                               length_for_pred)

        # CTC head; permute(1, 0, 2) gives the (T, N, C) CTC layout.
        preds_size = torch.IntTensor([preds_ctc.size(1)] * batch_size)
        if opt.baiduCTC:
            cost_ctc = criterion_ctc(preds_ctc.permute(1, 0, 2),
                                     text_for_loss_ctc, preds_size,
                                     length_for_loss_ctc) / batch_size
        else:
            cost_ctc = criterion_ctc(
                preds_ctc.log_softmax(2).permute(1, 0, 2), text_for_loss_ctc,
                preds_size, length_for_loss_ctc)

        # Greedy decoding: most probable class at every time step.
        _, preds_index = preds_ctc.max(2)
        if opt.baiduCTC:
            preds_index = preds_index.view(-1)
        preds_str = converter_ctc.decode(preds_index.data, preds_size.data)

        infer_time += forward_time
        valid_loss_avg_ctc.add(cost_ctc)
        valid_loss_avg_attn.add(cost_attn)

        # Per-step max probabilities, used for the confidence scores.
        preds_max_prob, _ = F.softmax(preds_ctc, dim=2).max(dim=2)
        preds_max_prob_attn, _ = F.softmax(preds_attn, dim=2).max(dim=2)

        # Assumes get_res returns per-batch totals (correct count, confidence
        # list, summed norm-ED), consistent with how n_correct is aggregated.
        n_correct, confidence_score_list, batch_norm_ED = get_res(
            labels, preds_str, preds_max_prob, opt, length_of_data)
        n_correct_attn, confidence_score_list_attn, batch_norm_ED_attn = get_res(
            labels,
            preds_str_attn,
            preds_max_prob_attn,
            opt,
            length_of_data,
            isattn=True)
        n_correct_all += n_correct
        n_correct_all_attn += n_correct_attn
        # Accumulate over batches so the division by length_of_data below
        # yields a dataset-wide mean.
        norm_ED_all += batch_norm_ED
        norm_ED_all_attn += batch_norm_ED_attn

    if length_of_data:
        accuracy = n_correct_all / float(length_of_data) * 100
        norm_ED = norm_ED_all / float(length_of_data)  # ICDAR2019 Normalized Edit Distance
        accuracy_attn = n_correct_all_attn / float(length_of_data) * 100
        norm_ED_attn = norm_ED_all_attn / float(length_of_data)
    else:
        accuracy = norm_ED = accuracy_attn = norm_ED_attn = 0.0
    # Dataset-wide totals behind the accuracies above.
    print(n_correct_all)
    print(n_correct_all_attn)
    print(length_of_data)
    return (valid_loss_avg_ctc.val(), accuracy, norm_ED, preds_str,
            confidence_score_list, labels, infer_time, length_of_data,
            valid_loss_avg_attn.val(), accuracy_attn, norm_ED_attn,
            preds_str_attn, confidence_score_list_attn)
Esempio n. 14
0
def train(opt):
    """Train a text recognizer with a joint CTC + attention objective.

    Builds a balanced training stream (``Batch_Balanced_Dataset``) and a
    shuffled validation loader, constructs ``Model(opt)`` (which returns the
    CTC predictions and a list of attention predictions) and minimizes
    ``0.1 * CTC loss + sum of attention cross-entropies``.  Every
    ``opt.valInterval`` iterations, and at iteration 0, it runs
    ``validation``, appends the results to
    ``./saved_models/<exp_name>/log_train.txt``, logs to TensorBoard and saves
    ``best_accuracy.pth`` / ``best_norm_ED.pth``.  A checkpoint
    ``iter_<N>.pth`` is also written every 1e5 iterations.

    The function never returns normally: it calls ``sys.exit()`` after
    ``opt.num_iter`` iterations or as soon as the loss becomes NaN/Inf.

    Args:
        opt: argparse-style options.  Mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``; ``num_class_ctc``,
            ``num_class_atten`` and, when ``opt.rgb`` is set,
            ``input_channel`` are assigned.
    """
    # --- dataset preparation ---
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    # Context manager: the log file is closed even if a write fails.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # --- model configuration ---
    converter_ctc = CTCLabelConverter(opt.character)  # CTC head
    converter_atten = AttnLabelConverter(opt.character)  # attention head
    opt.num_class_ctc = len(converter_ctc.character)
    opt.num_class_atten = len(converter_atten.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class_ctc, opt.num_class_atten, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling,
          opt.Prediction)

    # Weight initialization: zero biases, Kaiming-normal weights.
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:
            # kaiming_normal_ rejects tensors it cannot compute a fan for
            # (e.g. 1-D batchnorm weights); initialise those to one instead.
            if 'weight' in name:
                param.data.fill_(1)

    # Only parameters that require gradients are handed to the optimizer.
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    print('Trainable params num : ',
          sum(np.prod(p.size()) for p in filtered_parameters))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # use fp16 to train
    model = model.to(device)
    if opt.fp16:
        with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
            log.write('==> Enable fp16 training' + '\n')
        print('==> Enable fp16 training')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # data parallel for multi-GPU
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # --- setup loss ---
    criterion_ctc = torch.nn.CTCLoss(zero_infinity=True).to(device)
    criterion_atten = torch.nn.CrossEntropyLoss(ignore_index=0).to(
        device)  # ignore [GO] token = ignore index 0

    # loss averager
    loss_avg = Averager()

    # --- final options ---
    writer = SummaryWriter(f'./saved_models/{opt.exp_name}')
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # The checkpoint name carries no iteration number; start from 0.
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        batch_size = image.size(0)
        text_ctc, length_ctc = converter_ctc.encode(
            labels, batch_max_length=opt.batch_max_length)
        text_atten, length_atten = converter_atten.encode(
            labels, batch_max_length=opt.batch_max_length)

        # Output is a tuple (tensor, list); text_atten[:, :-1] aligns with
        # Attention.forward.
        preds_ctc, preds_atten = model(image, text_atten[:, :-1])

        # CTC loss: log-probabilities permuted to (T, N, C) for nn.CTCLoss.
        preds_size = torch.IntTensor([preds_ctc.size(1)] * batch_size)
        preds_ctc = preds_ctc.log_softmax(2).permute(1, 0, 2)
        cost_ctc = 0.1 * criterion_ctc(preds_ctc, text_ctc, preds_size,
                                       length_ctc)

        # Attention loss: one cross-entropy term per element of preds_atten,
        # all against the same target without the [GO] symbol.
        target = text_atten[:, 1:].contiguous().view(-1)
        cost_atten = sum(
            criterion_atten(pred.view(-1, pred.shape[-1]), target)
            for pred in preds_atten)
        cost = cost_ctc + cost_atten
        writer.add_scalar('loss', cost.item(), global_step=iteration + 1)

        # Progress line is rewritten in place; a newline every 100 iterations.
        print('\riter: {:4d}\tloss: {:6.3f}\tavg: {:6.3f}'.format(
            iteration + 1, cost.item(), loss_avg.val()),
              end='\n' if (iteration + 1) % 100 == 0 else '')
        sys.stdout.flush()
        if cost < 0.001:
            print(f'iter: {iteration + 1}\tloss: {cost}')

        optimizer.zero_grad()
        if torch.isnan(cost):
            print(f'iter: {iteration + 1}\tloss: {cost}\t==> Loss is NAN')
            writer.close()  # flush pending TensorBoard events before exiting
            sys.exit()
        elif torch.isinf(cost):
            print(f'iter: {iteration + 1}\tloss: {cost}\t==> Loss is INF')
            writer.close()  # flush pending TensorBoard events before exiting
            sys.exit()
        if opt.fp16:
            with amp.scale_loss(cost, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)
        writer.add_scalar('loss_avg',
                          loss_avg.val(),
                          global_step=iteration + 1)

        # validation part; also run at iteration 0 to see training progress.
        if iteration == 0 or (iteration + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion_atten, valid_loader, converter_atten,
                        opt)
                model.train()
                writer.add_scalar('accuracy',
                                  current_accuracy,
                                  global_step=iteration + 1)

                # training loss and validation loss
                loss_log = f'[{iteration + 1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    gt = gt[:gt.find('[s]')]
                    pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(
                model.state_dict(),
                f'./saved_models/{opt.exp_name}/iter_{iteration + 1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            writer.close()  # flush pending TensorBoard events before exiting
            sys.exit()

        iteration += 1
def _latent_only_style(style, latent):
    """Trim each style code to its first ``latent`` columns.

    Used when ``opt.zAlone`` is set ("to validate orig style gan results"):
    the columns appended after the latent part (presumably the PHOC word
    code — confirm against ``mixing_noise``) are dropped.  Only the first two
    entries of ``style`` are kept, exactly as the original inline code did.

    Args:
        style: sequence of tensors sliceable as ``[:, :latent]``.
        latent: number of leading columns to keep.

    Returns:
        A new list holding one or two trimmed tensors.
    """
    trimmed = [style[0][:, :latent]]
    if len(style) > 1:
        trimmed.append(style[1][:, :latent])
    return trimmed


def _checkpoint_iteration(path):
    """Return ``N`` from an ``iter_<N>_synth.pth`` file name, or -1 if it has none."""
    try:
        return int(os.path.basename(path).split('_')[-2])
    except (ValueError, IndexError):
        return -1


def train(opt):
    """Train a PHOC-conditioned StyleGAN word-image generator.

    Each loop iteration:
      1. updates the discriminator/encoder ``disEncModel`` with
         ``d_logistic_loss`` plus MSE losses that regress the code of
         generated images (weighted by ``opt.gamma_e``) and the PHOC of
         ground-truth images (weighted by ``opt.beta``); every
         ``opt.d_reg_every`` iterations an extra ``d_r1_loss`` step is taken;
      2. updates ``genModel`` with ``g_nonsaturating_loss`` plus the code
         regression loss on its samples (weighted by ``opt.gamma_g``); every
         ``opt.g_reg_every`` iterations an extra path-length step is taken;
      3. moves ``g_ema`` towards ``genModel`` via ``accumulate``.

    Every ``opt.valInterval`` iterations (and at iteration 0) paired samples
    from ``g_ema`` are written under ``opt.trainDir/<iteration>/`` and the
    averaged losses are printed, appended to ``log_train.txt`` and plotted via
    ``lib.plot``.  A checkpoint ``iter_<iteration + 1>_synth.pth`` is saved
    whenever ``iteration % 1e4 == 0``.  The function ends with ``sys.exit()``
    once ``opt.num_iter`` is reached.

    Args:
        opt: argparse-style options.  ``opt.classes`` and ``opt.num_class``
            are assigned; with ``opt.modelFolderFlag`` the newest checkpoint in
            the experiment folder is written to ``opt.saved_synth_model``.
    """
    lib.print_model_settings(locals().copy())

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)
    opt.classes = converter.character

    # --- dataset preparation ---
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    AlignCollate_valid = AlignPHOCCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)

    train_dataset = LmdbStylePHOCDataset(root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        # *2: the first half feeds the encoder, the second half serves as
        # real images for the discriminator.
        batch_size=opt.batch_size * 2,
        shuffle=True,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True,
        drop_last=True)

    print('-' * 80)

    # NOTE(review): valid_loader is built but not used anywhere below.
    valid_dataset = LmdbStylePHOCDataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size * 2,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True,
        drop_last=True)

    print('-' * 80)
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'),
              'a') as log:
        log.write('-' * 80 + '\n')

    phoc_dataset = phoc_gen(opt)
    phoc_loader = torch.utils.data.DataLoader(phoc_dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=True,
                                              num_workers=int(opt.workers),
                                              pin_memory=True,
                                              drop_last=True)
    opt.num_class = len(converter.character)

    # Without zAlone the generator's style input is latent + PHOC code.
    style_dim = opt.latent if opt.zAlone else opt.latent + phoc_dataset.phoc_size
    genModel = styleGANGen(opt.size,
                           style_dim,
                           opt.latent,
                           opt.n_mlp,
                           channel_multiplier=opt.channel_multiplier)
    g_ema = styleGANGen(opt.size,
                        style_dim,
                        opt.latent,
                        opt.n_mlp,
                        channel_multiplier=opt.channel_multiplier)
    disEncModel = styleGANDis(opt.size,
                              channel_multiplier=opt.channel_multiplier,
                              input_dim=opt.input_channel,
                              code_s_dim=phoc_dataset.phoc_size)

    accumulate(g_ema, genModel, 0)

    uCriterion = torch.nn.MSELoss()
    sCriterion = torch.nn.MSELoss()

    genModel = torch.nn.DataParallel(genModel).to(device)
    g_ema = torch.nn.DataParallel(g_ema).to(device)
    genModel.train()
    g_ema.eval()

    disEncModel = torch.nn.DataParallel(disEncModel).to(device)
    disEncModel.train()

    g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1)
    d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1)

    optimizer = optim.Adam(
        genModel.parameters(),
        lr=opt.lr * g_reg_ratio,
        betas=(0**g_reg_ratio, 0.99**g_reg_ratio),
    )
    dis_optimizer = optim.Adam(
        disEncModel.parameters(),
        lr=opt.lr * d_reg_ratio,
        betas=(0**d_reg_ratio, 0.99**d_reg_ratio),
    )

    # Loading pre-trained files: optionally resume from the newest checkpoint.
    if opt.modelFolderFlag:
        candidates = glob.glob(
            os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))
        if candidates:
            # glob returns files in arbitrary order; pick the highest iteration.
            opt.saved_synth_model = max(candidates, key=_checkpoint_iteration)

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)

        genModel.load_state_dict(checkpoint['genModel'])
        g_ema.load_state_dict(checkpoint['g_ema'])
        disEncModel.load_state_dict(checkpoint['disEncModel'])

        optimizer.load_state_dict(checkpoint["optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])

    # loss averagers
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    loss_avg_unsup = Averager()
    loss_avg_sup = Averager()
    log_r1_val = Averager()
    log_avg_path_loss_val = Averager()
    log_avg_mean_path_length_avg = Averager()
    log_ada_aug_p = Averager()

    # --- final options ---
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'),
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        try:
            start_iter = int(
                opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except (ValueError, IndexError):
            # The file name does not encode an iteration; start from 0.
            pass

    # get schedulers
    scheduler = get_scheduler(optimizer, opt)
    dis_scheduler = get_scheduler(dis_optimizer, opt)

    start_time = time.time()
    iteration = start_iter
    cntr = 0

    mean_path_length = 0
    r1_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0

    accum = 0.5**(32 / (10 * 1000))  # EMA decay passed to accumulate()
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0
    ada_aug_step = opt.ada_target / opt.ada_length
    r_t_stat = 0

    while True:
        # NOTE(review): PyTorch >= 1.1 expects scheduler.step() after
        # optimizer.step(); calling it first shifts the schedule by one step.
        if opt.lr_policy != "None":
            scheduler.step()
            dis_scheduler.step()

        # ---- [Discriminator / encoder] update ----
        # A fresh iterator per draw samples a new (shuffled) first batch.
        # The builtin next() is used because DataLoader iterators in newer
        # PyTorch releases no longer provide a `.next()` method.
        image_input_tensors, _, labels_1, _, phoc_1, _ = next(
            iter(train_loader))
        z_code, z_labels = next(iter(phoc_loader))

        image_input_tensors = image_input_tensors.to(device)
        gt_image_tensors = image_input_tensors[:opt.batch_size]
        real_image_tensors = image_input_tensors[opt.batch_size:]
        phoc_1 = phoc_1.to(device)
        gt_phoc_tensors = phoc_1[:opt.batch_size]
        labels_1 = labels_1[:opt.batch_size]
        z_code = z_code.to(device)

        requires_grad(genModel, False)
        requires_grad(disEncModel, True)

        # NOTE(review): the encoded labels are not used below.
        text_1, length_1 = converter.encode(
            labels_1, batch_max_length=opt.batch_max_length)

        style = mixing_noise(z_code, opt.batch_size, opt.latent, opt.mixing,
                             device)
        if opt.zAlone:
            style = _latent_only_style(style, opt.latent)

        fake_img, _ = genModel(style, input_is_latent=opt.input_latent)

        # unsupervised code prediction on generated image
        u_pred_code = disEncModel(fake_img, mode='enc')
        uCost = uCriterion(u_pred_code, z_code)

        # supervised code prediction on gt image
        s_pred_code = disEncModel(gt_image_tensors, mode='enc')
        sCost = sCriterion(s_pred_code, gt_phoc_tensors)

        # Domain discriminator
        fake_pred = disEncModel(fake_img)
        real_pred = disEncModel(real_image_tensors)
        disCost = d_logistic_loss(real_pred, fake_pred)
        # Detached mean scores, reported to wandb only.
        real_score = real_pred.detach().mean()
        fake_score = fake_pred.detach().mean()

        dis_enc_cost = disCost + opt.gamma_e * uCost + opt.beta * sCost

        loss_avg_dis.add(disCost)
        loss_avg_sup.add(opt.beta * sCost)
        loss_avg_unsup.add(opt.gamma_e * uCost)

        disEncModel.zero_grad()
        dis_enc_cost.backward()
        dis_optimizer.step()

        d_regularize = cntr % opt.d_reg_every == 0

        if d_regularize:
            real_image_tensors.requires_grad = True

            real_pred = disEncModel(real_image_tensors)

            r1_loss = d_r1_loss(real_pred, real_image_tensors)

            disEncModel.zero_grad()
            (opt.r1 / 2 * r1_loss * opt.d_reg_every +
             0 * real_pred[0]).backward()

            dis_optimizer.step()

        # ---- [Word Generator] update ----
        image_input_tensors, _, labels_1, _, phoc_1, _ = next(
            iter(train_loader))
        z_code, z_labels = next(iter(phoc_loader))

        image_input_tensors = image_input_tensors.to(device)
        gt_image_tensors = image_input_tensors[:opt.batch_size]
        real_image_tensors = image_input_tensors[opt.batch_size:]
        phoc_1 = phoc_1.to(device)
        gt_phoc_tensors = phoc_1[:opt.batch_size]
        labels_1 = labels_1[:opt.batch_size]
        z_code = z_code.to(device)

        requires_grad(genModel, True)
        requires_grad(disEncModel, False)

        text_1, length_1 = converter.encode(
            labels_1, batch_max_length=opt.batch_max_length)

        style = mixing_noise(z_code, opt.batch_size, opt.latent, opt.mixing,
                             device)
        if opt.zAlone:
            style = _latent_only_style(style, opt.latent)
        fake_img, _ = genModel(style, input_is_latent=opt.input_latent)

        # unsupervised code prediction on generated image
        u_pred_code = disEncModel(fake_img, mode='enc')
        uCost = uCriterion(u_pred_code, z_code)

        fake_pred = disEncModel(fake_img)
        disGenCost = g_nonsaturating_loss(fake_pred)

        gen_enc_cost = disGenCost + opt.gamma_g * uCost
        loss_avg_gen.add(disGenCost)
        loss_avg_unsup.add(opt.gamma_g * uCost)

        genModel.zero_grad()
        disEncModel.zero_grad()

        gen_enc_cost.backward()
        optimizer.step()

        g_regularize = cntr % opt.g_reg_every == 0

        if g_regularize:
            image_input_tensors, _, labels_1, _, phoc_1, _ = next(
                iter(train_loader))
            z_code, z_labels = next(iter(phoc_loader))

            image_input_tensors = image_input_tensors.to(device)
            path_batch_size = max(1, opt.batch_size // opt.path_batch_shrink)

            gt_image_tensors = image_input_tensors[:path_batch_size]
            phoc_1 = phoc_1.to(device)
            gt_phoc_tensors = phoc_1[:path_batch_size]
            labels_1 = labels_1[:path_batch_size]
            z_code = z_code.to(device)
            z_code = z_code[:path_batch_size]
            z_labels = z_labels[:path_batch_size]

            text_1, length_1 = converter.encode(
                labels_1, batch_max_length=opt.batch_max_length)

            style = mixing_noise(z_code, path_batch_size, opt.latent,
                                 opt.mixing, device)
            if opt.zAlone:
                style = _latent_only_style(style, opt.latent)

            fake_img, grad = genModel(style,
                                      return_latents=True,
                                      g_path_regularize=True,
                                      mean_path_length=mean_path_length)

            decay = 0.01
            path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

            # Running mean of path lengths, updated with rate `decay`.
            mean_path_length_orig = mean_path_length + decay * (
                path_lengths.mean() - mean_path_length)
            path_loss = (path_lengths - mean_path_length_orig).pow(2).mean()
            mean_path_length = mean_path_length_orig.detach().item()

            genModel.zero_grad()
            weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss

            if opt.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            optimizer.step()

            # Single-process (DataParallel) setting: no cross-rank reduction.
            mean_path_length_avg = mean_path_length

        accumulate(g_ema, genModel, accum)

        log_r1_val.add(r1_loss)
        log_avg_path_loss_val.add(path_loss)
        log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg))
        log_ada_aug_p.add(torch.tensor(ada_aug_p))

        if get_rank() == 0:
            if wandb and opt.wandb:
                # All values come from this iteration's tensors (the previous
                # code referenced names that were never defined).
                wandb.log({
                    "Generator": disGenCost.item(),
                    "Discriminator": disCost.item(),
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_loss.item(),
                    "Path Length Regularization": path_loss.item(),
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score.item(),
                    "Fake Score": fake_score.item(),
                    "Path Length": path_lengths.mean().item(),
                })

        # validation part; also run at iteration 0 to see training progress.
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:

            # generate paired content with similar style
            z_code_1, z_labels_1 = next(iter(phoc_loader))
            z_code_2, z_labels_2 = next(iter(phoc_loader))
            z_code_1 = z_code_1.to(device)
            z_code_2 = z_code_2.to(device)

            style_1 = mixing_noise(z_code_1, opt.batch_size, opt.latent,
                                   opt.mixing, device)
            # style_2 keeps style_1's latent columns but uses z_code_2.
            style_2 = [torch.cat((style_1[0][:, :opt.latent], z_code_2), dim=1)]
            if len(style_1) > 1:
                style_2.append(
                    torch.cat((style_1[1][:, :opt.latent], z_code_2), dim=1))

            if opt.zAlone:
                # to validate orig style gan results
                style_1 = style_2 = _latent_only_style(style_1, opt.latent)

            # Sampling only: g_ema is never set to requires_grad(False) here,
            # so disable autograd to avoid building a graph.
            with torch.no_grad():
                fake_img_1, _ = g_ema(style_1,
                                      input_is_latent=opt.input_latent)
                fake_img_2, _ = g_ema(style_2,
                                      input_is_latent=opt.input_latent)

            os.makedirs(os.path.join(opt.trainDir, str(iteration)),
                        exist_ok=True)
            for trImgCntr in range(opt.batch_size):
                try:
                    save_image(
                        tensor2im(fake_img_1[trImgCntr].detach()),
                        os.path.join(
                            opt.trainDir, str(iteration),
                            str(trImgCntr) + '_pair1_' +
                            z_labels_1[trImgCntr] + '.png'))
                    save_image(
                        tensor2im(fake_img_2[trImgCntr].detach()),
                        os.path.join(
                            opt.trainDir, str(iteration),
                            str(trImgCntr) + '_pair2_' +
                            z_labels_2[trImgCntr] + '.png'))
                except Exception as e:
                    # Best effort: a failed sample save must not stop training.
                    print(f'Warning while saving training image: {e}')

            elapsed_time = time.time() - start_time
            # for log

            with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'),
                      'a') as log:

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}]  \
                    Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\
                    Train UnSup loss: {loss_avg_unsup.val():0.5f}, Train Sup loss: {loss_avg_sup.val():0.5f}, \
                    Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \
                    Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \
                    Elapsed_time: {elapsed_time:0.5f}'

                # plotting
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Dis-Loss'),
                              loss_avg_dis.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Gen-Loss'),
                              loss_avg_gen.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-UnSup-Loss'),
                              loss_avg_unsup.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Sup-Loss'),
                              loss_avg_sup.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-r1_val'),
                              log_r1_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-path_loss_val'),
                              log_avg_path_loss_val.val().item())
                lib.plot.plot(
                    os.path.join(opt.plotDir, 'Train-mean_path_length_avg'),
                    log_avg_mean_path_length_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-ada_aug_p'),
                              log_ada_aug_p.val().item())

                print(loss_log)
                # Previously the log file was opened but never written to.
                log.write(loss_log + '\n')

                loss_avg_dis.reset()
                loss_avg_gen.reset()
                loss_avg_unsup.reset()
                loss_avg_sup.reset()
                log_r1_val.reset()
                log_avg_path_loss_val.reset()
                log_avg_mean_path_length_avg.reset()
                log_ada_aug_p.reset()

            lib.plot.flush()

        lib.plot.tick()

        # save a checkpoint every 1e+4 iterations (including iteration 0)
        if (iteration) % 1e+4 == 0:
            torch.save(
                {
                    'genModel': genModel.state_dict(),
                    'g_ema': g_ema.state_dict(),
                    'disEncModel': disEncModel.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'dis_optimizer': dis_optimizer.state_dict()
                },
                os.path.join(opt.exp_dir, opt.exp_name,
                             'iter_' + str(iteration + 1) + '_synth.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
        cntr += 1
Esempio n. 16
0
def train(opt):
    """Train a text-recognition model configured by ``opt``.

    Builds the balanced training dataset and a shuffled validation loader,
    creates the label converter, model and loss matching ``opt.Prediction``
    (CTC, optionally Baidu warp-ctc, or attention), optionally loads
    ``opt.saved_model``, and then loops forever:

    * one optimizer step per batch, with gradient-norm clipping at
      ``opt.grad_clip``;
    * every 100 iterations the running train loss goes to TensorBoard;
    * every ``opt.valInterval`` iterations the model is validated, logged to
      ``./saved_models/<exp_name>/log_train.txt`` and the best-accuracy /
      best-norm_ED checkpoints are refreshed;
    * every 100000 iterations an ``iter_<n>.pth`` checkpoint is written.

    The process exits through ``sys.exit()`` when ``iteration + 1`` reaches
    ``opt.num_iter``.

    Args:
        opt: options namespace. Mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set from
            the converter and ``input_channel`` is forced to 3 when
            ``opt.rgb`` is set.

    Note:
        Uses module-level ``device`` and ``writer`` objects that are not
        defined in this function (NOTE(review): confirm ``writer`` exists at
        module scope before calling).
    """
    # --- dataset preparation ---
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager closes the dataset log even if building the
    # validation set raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            # 'True' to check training progress with validation function.
            shuffle=True,
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # Weight initialization: biases -> 0, weights -> Kaiming normal. When the
    # initializer raises (e.g. 1-D batchnorm weights), weights are set to 1.
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss_avg is reset at each validation; loss_avg2 every 100 iterations.
    loss_avg = Averager()
    loss_avg2 = Averager()

    # filter that only require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # --- final options ---
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0
    if opt.saved_model != '':
        # Checkpoints named like ".../iter_<n>.pth" resume the counter at n;
        # any other name (non-numeric suffix) simply starts from 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        if iteration == start_iter:
            # Record the model graph once, on the first (possibly resumed) step.
            writer.add_graph(model, (image, text))
        batch_size = image.size(0)
        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)
        loss_avg2.add(cost)
        if (iteration + 1) % 100 == 0:
            writer.add_scalar("Loss/train", loss_avg2.val(),
                              (iteration + 1) // 100)
            loss_avg2.reset()

        # validation part (add `or iteration == 0` to also validate before
        # any training has happened)
        if (iteration + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                train_loss = loss_avg.val()
                val_step = (iteration + 1) // opt.valInterval
                writer.add_scalar("Loss/valid", valid_loss, val_step)
                writer.add_scalar("Metrics/accuracy", current_accuracy,
                                  val_step)
                writer.add_scalar("Metrics/norm_ED", current_norm_ED,
                                  val_step)
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {train_loss:0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # Cut at the end-of-sequence token only when present:
                        # str.find returns -1 otherwise, and slicing with -1
                        # would drop the last character.
                        gt_eos = gt.find('[s]')
                        if gt_eos != -1:
                            gt = gt[:gt_eos]
                        pred_eos = pred.find('[s]')
                        if pred_eos != -1:
                            pred = pred[:pred_eos]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model every 1e+5 iterations.
        if (iteration + 1) % 100000 == 0:
            torch.save(
                model.state_dict(),
                f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
Esempio n. 17
0
def train(opt):
    """Fine-tune a pretrained attention recognizer on ``opt.character``.

    The network is first built with 38 output classes so that
    ``opt.saved_model`` can be loaded. Its attention LSTM cell and generator
    are then replaced with fresh layers sized for ``opt.num_class``, and the
    whole model is trained with cross-entropy (ignoring index 0, the [GO]
    token).

    Validation runs whenever ``i % opt.valInterval == 0``, including the
    first iteration. Logs and checkpoints go to
    ``<save_dir>/<experiment_name>/``, and when ``opt.use_tb`` is set,
    losses are also written to TensorBoard. The process exits via
    ``sys.exit()`` once ``i == opt.num_iter``, so iteration ``num_iter``
    itself is also trained.

    Args:
        opt: options namespace. Mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set from
            the converter and ``input_channel`` is forced to 3 when
            ``opt.rgb`` is set.

    Note:
        Uses module-level ``device`` and ``save_dir`` defined elsewhere.
    """
    if opt.use_tb:
        tb_dir = f'/home_hongdo/{getpass.getuser()}/tb/{opt.experiment_name}'
        print('tensorboard : ', tb_dir)
        os.makedirs(tb_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=tb_dir)

    # --- dataset preparation ---
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager closes the dataset log even if building the
    # validation set raises.
    with open(f'{save_dir}/{opt.experiment_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            # 'True' to check training progress with validation function.
            shuffle=True,
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # --- model configuration ---
    converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    # sekim for transfer learning: build with the class count of the
    # pretrained checkpoint so its weights load. NOTE(review): 38 is presumably
    # 36 alphanumerics + [GO] + [s]; confirm against the checkpoint.
    model = Model(opt, 38)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length)

    # Weight initialization: biases -> 0, weights -> Kaiming normal. When the
    # initializer raises (e.g. 1-D batchnorm weights), weights are set to 1.
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)

    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))

    # sekim change last layer: replace the attention cell's LSTM and the output
    # generator with layers sized for opt.num_class.
    # NOTE(review): 256 is hard-coded and presumably equals the attention
    # hidden size of the loaded model; confirm against opt.hidden_size.
    in_feature = model.module.Prediction.generator.in_features
    model.module.Prediction.attention_cell.rnn = nn.LSTMCell(
        256 + opt.num_class, 256).to(device)
    model.module.Prediction.generator = nn.Linear(in_feature,
                                                  opt.num_class).to(device)

    print(model.module.Prediction.generator)
    print("Model:")
    print(model)

    model.train()

    # --- setup loss ---
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
        device)  # ignore [GO] token = ignore index 0
    # loss averager, reset at each validation
    loss_avg = Averager()

    # Filter to the parameters that require gradient descent. This runs after
    # the head swap, so the new layers are included.
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # --- final options ---
    with open(f'{save_dir}/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0

    if opt.saved_model != '':
        # Checkpoints named like ".../iter_<n>.pth" resume the counter at n;
        # any other name (non-numeric suffix) simply starts from 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print("-------------------------------------------------")
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)

        preds = model(image, text[:, :-1])  # align with Attention.forward
        target = text[:, 1:]  # without [GO] Symbol
        cost = criterion(preds.view(-1, preds.shape[-1]),
                         target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'{save_dir}/{opt.experiment_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)

                model.train()

                # Read the running train loss BEFORE reset() so the text log
                # and TensorBoard both record the real value.
                train_loss = loss_avg.val()
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {train_loss:0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()
                if opt.use_tb:
                    writer.add_scalar('OCR_loss/train_loss', train_loss, i)
                    writer.add_scalar('OCR_loss/validation_loss', valid_loss,
                                      i)

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'{save_dir}/{opt.experiment_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'{save_dir}/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    # Cut at the end-of-sequence token only when present:
                    # str.find returns -1 otherwise, and slicing with -1
                    # would drop the last character.
                    gt_eos = gt.find('[s]')
                    if gt_eos != -1:
                        gt = gt[:gt_eos]
                    pred_eos = pred.find('[s]')
                    if pred_eos != -1:
                        pred = pred[:pred_eos]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model every 1e+5 iterations.
        if (i + 1) % 100000 == 0:
            torch.save(model.state_dict(),
                       f'{save_dir}/{opt.experiment_name}/iter_{i + 1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
Esempio n. 18
0
def train(opt):
    """Train a model with joint CTC ("refiner") and attention losses.

    The model is called as ``model(image, attn_text[:, :-1])`` and returns
    ``(preds, refiner)``. ``refiner`` is scored with CTC loss against the CTC
    encoding of the labels. Every element of ``preds`` (presumably one
    prediction per decoder stage; TODO confirm) is scored with cross-entropy
    against the attention encoding. The total loss is::

        opt.lambda_ctc * ctc + opt.lambda_attn * sum(attention losses)

    Gradient clipping is tightened to 2 and then 1 once the running loss
    drops to 0.6 and 0.3. Validation runs on the first iteration and every
    ``opt.valInterval`` iterations. Logs and best checkpoints go to
    ``./saved_models/<exp_name>/``, and ``SCATTER_STR.pth`` is overwritten
    every 1000 iterations. Training resumes from the iteration number parsed
    from ``opt.saved_model`` (when it has one). The process exits via
    ``sys.exit()`` after iteration ``opt.num_iter - 1``.

    Args:
        opt: options namespace. Mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` and possibly
            ``input_channel`` / ``grad_clip`` are overwritten.

    Note:
        Uses a module-level ``device`` defined elsewhere.
    """
    os.makedirs(opt.log, exist_ok=True)
    writer = SummaryWriter(opt.log)

    # --- dataset preparation ---
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager closes the dataset log even if building the
    # validation set raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            # 'True' to check training progress with validation function.
            shuffle=True,
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)

        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # --- model configuration ---
    ctc_converter = CTCLabelConverter(opt.character)
    attn_converter = AttnLabelConverter(opt.character)
    opt.num_class = len(attn_converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length)

    # Weight initialization: biases -> 0, weights -> Kaiming normal. When the
    # initializer raises (e.g. 1-D batchnorm weights), weights are set to 1.
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))

    # --- setup loss ---
    loss_avg = Averager()
    ctc_loss = torch.nn.CTCLoss(zero_infinity=True).to(device)
    # ignore [GO] token = ignore index 0
    attn_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # filter that only require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # --- final options ---
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0
    if opt.saved_model != '':
        # Checkpoints named like ".../iter_<n>.pth" resume the counter at n;
        # any other name (non-numeric suffix) simply starts from 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    # Honor start_iter so a resumed run continues its iteration count instead
    # of restarting at 0.
    pbar = tqdm(range(start_iter, opt.num_iter))

    for iteration in pbar:

        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        ctc_text, ctc_length = ctc_converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        attn_text, attn_length = attn_converter.encode(
            labels, batch_max_length=opt.batch_max_length)

        batch_size = image.size(0)

        preds, refiner = model(image, attn_text[:, :-1])

        # CTC expects (T, N, C) log-probabilities.
        refiner_size = torch.IntTensor([refiner.size(1)] * batch_size)
        refiner = refiner.log_softmax(2).permute(1, 0, 2)
        refiner_loss = ctc_loss(refiner, ctc_text, refiner_size, ctc_length)

        total_loss = opt.lambda_ctc * refiner_loss
        target = attn_text[:, 1:]  # without [GO] Symbol
        for pred in preds:
            total_loss += opt.lambda_attn * attn_loss(
                pred.view(-1, pred.shape[-1]),
                target.contiguous().view(-1))

        model.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        loss_avg.add(total_loss)
        # Tighten gradient clipping as the running loss falls.
        if loss_avg.val() <= 0.6:
            opt.grad_clip = 2
        if loss_avg.val() <= 0.3:
            opt.grad_clip = 1

        # Drop this step's references to the model outputs and input batch so
        # empty_cache() can release their cached GPU blocks (no host copies
        # are needed; none of these are read again before being reassigned).
        del preds, refiner, image
        torch.cuda.empty_cache()

        writer.add_scalar('train_loss', loss_avg.val(), iteration)
        pbar.set_description('Iteration {0}/{1}, AvgLoss {2}'.format(
            iteration, opt.num_iter, loss_avg.val()))

        # validation part: on the first iteration and every valInterval
        if (iteration + 1) % opt.valInterval == 0 or iteration == start_iter:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, attn_loss, valid_loader, attn_converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                writer.add_scalar('Val_loss', valid_loss, iteration)
                pbar.set_description(loss_log)
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy_{str(best_accuracy)}.pth'
                    )
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'

                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if ('Attn' in opt.Prediction
                            or 'Transformer' in opt.Prediction):
                        # Cut at the end-of-sequence token only when present:
                        # str.find returns -1 otherwise, and slicing with -1
                        # would drop the last character.
                        gt_eos = gt.find('[s]')
                        if gt_eos != -1:
                            gt = gt[:gt_eos]
                        pred_eos = pred.find('[s]')
                        if pred_eos != -1:
                            pred = pred[:pred_eos]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                log.write(predicted_result_log + '\n')

        # save model every 1e+3 iterations (overwrites the same file).
        if (iteration + 1) % 1000 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.exp_name}/SCATTER_STR.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
Esempio n. 19
0
File: test.py Progetto: tnqkr98/ocr
def validation(model, criterion, evaluation_loader, converter, opt):
    """Evaluate ``model`` over every batch of ``evaluation_loader``.

    For each batch the model is run in prediction mode (CTC or attention head,
    chosen by whether ``'CTC'`` is in ``opt.Prediction``). The loss is
    accumulated, predictions are decoded greedily (arg-max), and the results
    are compared with the ground-truth labels. This function neither switches
    the model to eval mode nor disables autograd; callers are expected to do
    that.

    Args:
        model: recognition model, called as ``model(image, text)`` for CTC
            and ``model(image, text, is_train=False)`` otherwise.
        criterion: loss. For CTC it is called with
            ``(log_probs, targets, input_lengths, target_lengths)``; otherwise
            with flattened logits and targets.
        evaluation_loader: iterable of ``(image_tensors, labels)`` batches.
        converter: label converter providing ``encode`` and ``decode``.
        opt: options; reads ``Prediction`` and ``batch_max_length``.

    Returns:
        ``(valid_loss, accuracy, norm_ED, preds_str, labels, infer_time,
        length_of_data)``:

        * ``accuracy`` is the exact-match percentage.
        * ``norm_ED`` is the *sum* over samples of
          ``edit_distance(pred, gt) / len(gt)``, counting 1 for an empty
          ground truth.
        * ``preds_str`` and ``labels`` are from the last batch only.
        * ``infer_time`` is the summed forward time in seconds.

    Raises:
        ValueError: if ``evaluation_loader`` yields no samples.
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()

    for image_tensors, labels in evaluation_loader:
        batch_size = image_tensors.size(0)
        length_of_data += batch_size
        image = image_tensors.to(device)
        # For max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                          batch_size).to(device)
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                         1).fill_(0).to(device)

        text_for_loss, length_for_loss = converter.encode(
            labels, batch_max_length=opt.batch_max_length)

        start_time = time.time()
        if 'CTC' in opt.Prediction:
            preds = model(image, text_for_pred).log_softmax(2)
            forward_time = time.time() - start_time

            # Calculate evaluation loss for CTC decoder.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)  # to use CTCloss format

            # To avoid ctc_loss issue, disable cudnn for the computation of
            # the ctc_loss: https://github.com/jpuigcerver/PyLaia/issues/16
            # Restore the caller's setting afterwards, even if the loss raises.
            cudnn_was_enabled = torch.backends.cudnn.enabled
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(preds, text_for_loss, preds_size,
                                 length_for_loss)
            finally:
                torch.backends.cudnn.enabled = cudnn_was_enabled

            # Select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
            preds_str = converter.decode(preds_index.data, preds_size.data)

        else:
            preds = model(image, text_for_pred, is_train=False)
            forward_time = time.time() - start_time

            preds = preds[:, :text_for_loss.shape[1] - 1, :]
            target = text_for_loss[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # calculate accuracy.
        for pred, gt in zip(preds_str, labels):
            if 'Attn' in opt.Prediction:
                # Prune after the "end of sentence" token ([s]) only when it is
                # present: str.find returns -1 otherwise, and slicing with -1
                # would silently drop the last character.
                pred_eos = pred.find('[s]')
                if pred_eos != -1:
                    pred = pred[:pred_eos]
                gt_eos = gt.find('[s]')
                if gt_eos != -1:
                    gt = gt[:gt_eos]

            if pred == gt:
                n_correct += 1
            if len(gt) == 0:
                norm_ED += 1
            else:
                norm_ED += edit_distance(pred, gt) / len(gt)

    if length_of_data == 0:
        # Otherwise this would surface as ZeroDivisionError / unbound preds_str.
        raise ValueError('evaluation_loader yielded no samples')

    accuracy = n_correct / float(length_of_data) * 100

    return (valid_loss_avg.val(), accuracy, norm_ED, preds_str, labels,
            infer_time, length_of_data)
def validation(model, criterion, evaluation_loader, converter, opt):
    """Evaluate an OCR model on ``evaluation_loader`` with greedy decoding.

    Side effect: sets ``requires_grad = False`` on every parameter of
    ``model`` and does not restore it; the caller must re-enable gradients
    before resuming training. Requires a CUDA device (inputs are moved with
    ``.cuda()``).

    Args:
        model: recognizer, called as ``model(image, text_for_pred)`` for CTC
            and ``model(image, text_for_pred, is_train=False)`` otherwise.
        criterion: loss; for CTC it is called as
            ``criterion(preds, targets, preds_size, target_lengths)``,
            otherwise with flattened ``(logits, targets)``.
        evaluation_loader: iterable of ``(images, texts)`` batches.
        converter: label converter providing ``encode`` and ``decode``.
        opt: options object; only ``opt.Prediction`` is read here.

    Returns:
        Tuple ``(valid_loss, accuracy, norm_ED, sim_preds, cpu_texts,
        infer_time)``: the ``Averager.val()`` of the per-batch costs, word
        accuracy in percent (0.0 for an empty loader), the *sum* (not the
        mean) of per-sample normalized edit distances, the last batch's
        predicted and ground-truth strings, and the accumulated forward-pass
        wall time in seconds.
    """
    for p in model.parameters():
        p.requires_grad = False

    n_correct = 0
    norm_ED = 0
    # Hard-coded decoding length. NOTE(review): presumably matches the
    # batch_max_length used in training; confirm.
    max_length = 25
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()
    # Pre-bind the returned per-batch outputs so that an empty loader does not
    # raise UnboundLocalError at the return statement.
    sim_preds, cpu_texts = [], []

    def _prune_eos(s):
        # Keep the text before the first "[s]" end-of-sentence token. str.find
        # returns -1 when the token is absent, and s[:-1] would silently drop
        # the last character, so an EOS-less string is kept whole.
        eos = s.find('[s]')
        return s if eos == -1 else s[:eos]

    for cpu_images, cpu_texts in evaluation_loader:
        batch_size = cpu_images.size(0)
        length_of_data = length_of_data + batch_size
        with torch.no_grad():
            image = cpu_images.cuda()
            # For max length prediction
            length_for_pred = torch.cuda.IntTensor([max_length] * batch_size)
            text_for_pred = torch.cuda.LongTensor(batch_size,
                                                  max_length + 1).fill_(0)

            text_for_loss, length_for_loss = converter.encode(cpu_texts)

        # Wall time around the forward call. NOTE(review): CUDA runs
        # asynchronously, so this may not include all GPU work.
        start_time = time.time()
        if 'CTC' in opt.Prediction:
            preds = model(image, text_for_pred)
            forward_time = time.time() - start_time

            # Calculate evaluation loss for the CTC decoder.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss (T, N, C) format
            # NOTE(review): dividing by batch_size suggests the criterion
            # sums over the batch; confirm its reduction mode.
            cost = criterion(preds, text_for_loss, preds_size,
                             length_for_loss) / batch_size

            # Greedy decoding: most probable class per step, then decode.
            _, preds_index = preds.max(2)
            preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds_index.data, preds_size.data)

        else:
            preds = model(image, text_for_pred, is_train=False)
            forward_time = time.time() - start_time

            preds = preds[:, :text_for_loss.shape[1] - 1, :]
            target = text_for_loss[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

            # Greedy decoding: most probable class per step, then decode.
            _, preds_index = preds.max(2)
            sim_preds = converter.decode(preds_index, length_for_pred)
            # Ground truth is re-decoded from the encoded targets so both
            # strings carry the same "[s]" terminator before pruning.
            cpu_texts = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # calculate accuracy.
        for pred, gt in zip(sim_preds, cpu_texts):
            if 'CTC' not in opt.Prediction:
                pred = _prune_eos(pred)
                gt = _prune_eos(gt)

            if pred == gt:
                n_correct += 1
            if len(gt) == 0:
                # Nothing to normalize by: distance is 0 if the prediction is
                # also empty, otherwise the maximal normalized distance of 1.
                norm_ED += 1 if pred else 0
            else:
                norm_ED += edit_distance(pred, gt) / len(gt)

    accuracy = (n_correct / float(length_of_data) * 100
                if length_of_data else 0.0)

    return valid_loss_avg.val(
    ), accuracy, norm_ED, sim_preds, cpu_texts, infer_time
Esempio n. 21
0
    def test(self, model, target_path, dataloader):
        """Evaluate the saved checkpoint on ``dataloader`` and write a JSON log.

        Loads ``<save_path>/<args.name>/<args.name>.pth`` into ``model``
        (wrapped in ``torch.nn.DataParallel``) and evaluates every batch.
        Loss, accuracy and CER are computed over all batches. They are written,
        together with roughly ``pred_num`` sample (target, prediction) pairs,
        to ``<args.name>_test.log`` in the same folder.

        Args:
            model: recognizer exposing a ``classes`` attribute; called as
                ``model(img, text[:, :-1], max_len)``.
            target_path: unused; kept for interface compatibility.
            dataloader: yields batches indexed as ``batch[0]`` (images),
                ``batch[1][0]`` (encoded text) and ``batch[1][1]`` (lengths);
                its ``dataset.converter.decode`` turns indices into strings.

        Raises:
            FileNotFoundError: if the experiment folder does not exist.
        """
        save_folder = os.path.join(self.save_path, self.args.name)

        if not os.path.exists(save_folder):
            raise FileNotFoundError(f'No such folders {save_folder}')

        classes = model.classes
        model = torch.nn.DataParallel(model).to(self.device)
        model.load_state_dict(
            torch.load(os.path.join(save_folder, self.args.name + '.pth'),
                       map_location=self.device))

        loss_avg = Averager()
        calc = ScoreCalc()
        cer_avg = Averager()

        # Index 0 is excluded from the loss (ignore_index=0).
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(self.device)

        model.eval()
        pred_num = 10
        pred_result = []
        converter = dataloader.dataset.converter
        with tqdm(dataloader, unit="batch") as vepoch:
            # Record roughly `pred_num` sample predictions. max(1, ...) avoids
            # a modulo by zero when there are fewer than `pred_num` batches.
            sample_every = max(1, len(vepoch) // pred_num)
            for batch, batch_sampler in enumerate(vepoch):
                vepoch.set_description(f"Test Session / Batch {batch+1}")
                with torch.no_grad():
                    img = batch_sampler[0].to(self.device)
                    text = batch_sampler[1][0].to(self.device)
                    length = batch_sampler[1][1].to(self.device)

                    # Teacher-forced input drops the last token; the target
                    # drops the first one.
                    preds = model(img, text[:, :-1], max(length).cpu().numpy())
                    target = text[:, 1:]
                    v_cost = criterion(
                        preds.contiguous().view(-1, preds.shape[-1]),
                        target.contiguous().view(-1))
                    loss_avg.add(v_cost)

                    batch_size = len(text)
                    # Softmax once; reused for greedy decoding and scoring.
                    probs = F.softmax(preds, dim=2).view(batch_size, -1,
                                                         classes)
                    pred_max = torch.argmax(probs, 2)

                    calc.add(target, probs, length)

                    word_target = converter.decode(target, length)
                    word_preds = converter.decode(pred_max, length)

                    cer_avg.add(
                        torch.from_numpy(
                            np.array(get_cer(word_preds, word_target))))
                    vepoch.set_postfix(loss=loss_avg.val().item(),
                                       acc=calc.val().item(),
                                       cer=cer_avg.val().item())

                    if batch % sample_every == 0:
                        pred_str = unicodedata.normalize('NFC', word_preds[0])
                        target_str = unicodedata.normalize(
                            'NFC', word_target[0])
                        pred_result.append(
                            dict(target=target_str, pred=pred_str))

                    # Drop references to per-batch tensors before the next
                    # batch is loaded.
                    del batch_sampler, v_cost, pred_max, probs, preds, img, text, length

        log = dict()
        log['loss'] = loss_avg.val().item()
        log['acc'] = calc.val().item()
        log['cer'] = cer_avg.val().item()
        log['preds'] = pred_result

        with open(os.path.join(save_folder, f'{self.args.name}_test.log'),
                  'w',
                  encoding='utf-8') as f:
            json.dump(log, f, indent=2)
Esempio n. 22
0
def validation(model,
               criterion,
               evaluation_loader,
               converter,
               opt,
               eval_data=None):
    """Evaluate an OCR model (CTC, attention or Transformer decoder).

    Runs greedy decoding on every batch of ``evaluation_loader`` and
    accumulates loss, exact-match accuracy and normalized edit distance.

    Side effect: sets ``requires_grad = False`` on every parameter of
    ``model`` and does not restore it. Requires a CUDA device (inputs and
    helper tensors are created on ``cuda``).

    Args:
        model: recognizer; the call signature used depends on
            ``opt.SequenceModeling`` (Transformer takes ``tgt_pos``) and
            ``opt.Prediction`` (CTC vs. attention).
        criterion: loss; CTC is called as
            ``criterion(log_probs, targets, preds_size, target_lengths)``,
            the other decoders with flattened ``(logits, targets)``.
        evaluation_loader: iterable of ``(images, labels)`` batches; its
            ``batch_size`` attribute is read for the Transformer path.
        converter: label converter providing ``encode`` and ``decode``.
        opt: options; reads ``batch_max_length``, ``SequenceModeling`` and
            ``Prediction``.
        eval_data: unused; kept for interface compatibility.

    Returns:
        Tuple ``(valid_loss, accuracy, norm_ED, preds_str, labels,
        infer_time, length_of_data)``: the ``Averager.val()`` of the
        per-batch costs, word accuracy in percent (0.0 for an empty loader),
        the *sum* (not the mean) of per-sample normalized edit distances, the
        last batch's predicted and ground-truth strings, accumulated
        forward-pass wall time in seconds, and the number of samples seen.
    """
    for p in model.parameters():
        p.requires_grad = False

    n_correct = 0
    norm_ED = 0
    max_length = opt.batch_max_length
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()
    # Pre-bind the returned per-batch outputs so that an empty loader does not
    # raise UnboundLocalError at the return statement.
    preds_str, labels = [], []

    is_transformer = 'Transformer' in opt.SequenceModeling
    if is_transformer:
        # Target position ids 1..max_length+1, one row per sample of a full
        # batch; sliced to the actual batch size inside the loop.
        text_pos = torch.arange(1,
                                max_length + 2,
                                dtype=torch.long,
                                device='cuda').expand(
                                    evaluation_loader.batch_size, -1)

    # End-of-sentence marker used to prune decoded strings. It is chosen with
    # the same precedence as the decoding branches below; CTC has none.
    if is_transformer:
        eos_token = '</s>'
    elif 'Attn' in opt.Prediction:
        eos_token = '[s]'
    else:
        eos_token = None

    def _prune_eos(s):
        # Keep the text before the first EOS token. str.find returns -1 when
        # the token is absent, and s[:-1] would silently drop the last
        # character, so an EOS-less string is kept whole.
        eos = s.find(eos_token)
        return s if eos == -1 else s[:eos]

    for image_tensors, labels in evaluation_loader:
        batch_size = image_tensors.size(0)
        length_of_data = length_of_data + batch_size
        with torch.no_grad():
            image = image_tensors.cuda()
            # For max length prediction
            length_for_pred = torch.cuda.IntTensor([max_length] * batch_size)
            text_for_pred = torch.cuda.LongTensor(batch_size,
                                                  max_length + 1).fill_(0)

            if is_transformer:
                # The third value (target positions) is not needed here.
                text_for_loss, length_for_loss, _ = converter.encode(
                    labels, max_length)
            elif 'CTC' in opt.Prediction:
                text_for_loss, length_for_loss = converter.encode(labels)
            else:
                text_for_loss, length_for_loss = converter.encode(
                    labels, max_length)

        # Wall time around the forward call. NOTE(review): CUDA runs
        # asynchronously, so this may not include all GPU work.
        start_time = time.time()
        if is_transformer:
            batch_text_pos = text_pos[:batch_size]
            preds = model(image,
                          text_for_pred,
                          is_train=False,
                          tgt_pos=batch_text_pos)
            forward_time = time.time() - start_time
            preds = preds[:, :text_for_loss.shape[1] - 1, :]

            target = text_for_loss[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

            # Greedy decoding: most probable class per step, then decode.
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            labels = converter.decode(text_for_loss[:, 1:], length_for_loss)
        elif 'CTC' in opt.Prediction:
            preds = model(image, text_for_pred).log_softmax(2)
            forward_time = time.time() - start_time

            # Calculate evaluation loss for the CTC decoder.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss (T, N, C) format
            cost = criterion(preds, text_for_loss, preds_size, length_for_loss)

            # Greedy decoding: most probable class per step, then decode.
            _, preds_index = preds.max(2)
            preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
            preds_str = converter.decode(preds_index.data, preds_size.data)

        else:
            preds = model(image, text_for_pred, is_train=False)
            forward_time = time.time() - start_time

            preds = preds[:, :text_for_loss.shape[1] - 1, :]
            target = text_for_loss[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

            # Greedy decoding: most probable class per step, then decode.
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # calculate accuracy.
        for pred, gt in zip(preds_str, labels):
            if eos_token is not None:
                pred = _prune_eos(pred)
                gt = _prune_eos(gt)

            if pred == gt:
                n_correct += 1
            if len(gt) == 0:
                # Nothing to normalize by: distance is 0 if the prediction is
                # also empty, otherwise the maximal normalized distance of 1.
                norm_ED += 1 if pred else 0
            else:
                norm_ED += edit_distance(pred, gt) / len(gt)

    accuracy = (n_correct / float(length_of_data) * 100
                if length_of_data else 0.0)

    return valid_loss_avg.val(
    ), accuracy, norm_ED, preds_str, labels, infer_time, length_of_data
def train(opt):
    plotDir = os.path.join(opt.exp_dir,opt.exp_name,'plots')
    if not os.path.exists(plotDir):
        os.makedirs(plotDir)
    
    lib.print_model_settings(locals().copy())

    """ dataset preparation """
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')

    # #considering the real images for discriminator
    # opt.batch_size = opt.batch_size*2

    # train_dataset = Batch_Balanced_Dataset(opt)

    log = open(os.path.join(opt.exp_dir,opt.exp_name,'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)

    train_dataset, train_dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(train_dataset_log)
    print('-' * 80)

    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=False,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()
    
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    
    model = AdaINGenV4(opt)
    ocrModel = Model(opt)
    disModel = MsImageDisV1(opt)
    
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    
    #  weight initialization
    for currModel in [model, ocrModel, disModel]:
        for name, param in currModel.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception as e:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue
    
    
    # data parallel for multi-GPU
    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    if not opt.ocrFixed:
        ocrModel.train()
    else:
        ocrModel.module.Transformation.eval()
        ocrModel.module.FeatureExtraction.eval()
        ocrModel.module.AdaptiveAvgPool.eval()
        # ocrModel.module.SequenceModeling.eval()
        ocrModel.module.Prediction.eval()

    model = torch.nn.DataParallel(model).to(device)
    model.train()
    
    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    if opt.modelFolderFlag:
        
        if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth")))>0:
            opt.saved_synth_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_synth.pth"))[-1]
        
        if len(glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_dis.pth")))>0:
            opt.saved_dis_model = glob.glob(os.path.join(opt.exp_dir,opt.exp_name,"iter_*_dis.pth"))[-1]

    #loading pre-trained model
    if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None':
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        if opt.FT:
            ocrModel.load_state_dict(torch.load(opt.saved_ocr_model), strict=False)
        else:
            ocrModel.load_state_dict(torch.load(opt.saved_ocr_model))
    print("OCRModel:")
    print(ocrModel)

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_synth_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_synth_model))
    print("SynthModel:")
    print(model)

    if opt.saved_dis_model != '' and opt.saved_dis_model != 'None':
        print(f'loading pretrained discriminator model from {opt.saved_dis_model}')
        if opt.FT:
            disModel.load_state_dict(torch.load(opt.saved_dis_model), strict=False)
        else:
            disModel.load_state_dict(torch.load(opt.saved_dis_model))
    print("DisModel:")
    print(disModel)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    

    # recCriterion = torch.nn.L1Loss()
    # styleRecCriterion = torch.nn.L1Loss()

    if opt.imgReconLoss == 'l1':
        recCriterion = torch.nn.L1Loss()
    elif opt.imgReconLoss == 'ssim':
        recCriterion = ssim
    elif opt.imgReconLoss == 'ms-ssim':
        recCriterion = msssim

    if opt.styleLoss == 'l1':
        styleRecCriterion = torch.nn.L1Loss()
    elif opt.styleLoss == 'triplet':
        styleRecCriterion = torch.nn.TripletMarginLoss(margin=opt.tripletMargin, p=1)
    #for validation; check only positive pairs
    styleTestRecCriterion = torch.nn.L1Loss()


    # loss averager
    loss_avg_ocr = Averager()
    loss_avg = Averager()
    loss_avg_dis = Averager()

    loss_avg_ocrRecon_1 = Averager()
    loss_avg_ocrRecon_2 = Averager()
    loss_avg_gen = Averager()
    loss_avg_imgRecon = Averager()
    loss_avg_styRecon = Averager()

    ##---------------------------------------##
    # filter that only require gradient decent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.optim=='adam':
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay)
    print("SynthOptimizer:")
    print(optimizer)
    

    #filter parameters for OCR training
    ocr_filtered_parameters = []
    ocr_params_num = []
    for p in filter(lambda p: p.requires_grad, ocrModel.parameters()):
        ocr_filtered_parameters.append(p)
        ocr_params_num.append(np.prod(p.size()))
    print('OCR Trainable params num : ', sum(ocr_params_num))


    # setup optimizer
    if opt.optim=='adam':
        ocr_optimizer = optim.Adam(ocr_filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)
    else:
        ocr_optimizer = optim.Adadelta(ocr_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay)
    print("OCROptimizer:")
    print(ocr_optimizer)

    #filter parameters for OCR training
    dis_filtered_parameters = []
    dis_params_num = []
    for p in filter(lambda p: p.requires_grad, disModel.parameters()):
        dis_filtered_parameters.append(p)
        dis_params_num.append(np.prod(p.size()))
    print('Dis Trainable params num : ', sum(dis_params_num))

    # setup optimizer
    if opt.optim=='adam':
        dis_optimizer = optim.Adam(dis_filtered_parameters, lr=opt.lr, betas=(opt.beta1, opt.beta2), weight_decay=opt.weight_decay)
    else:
        dis_optimizer = optim.Adadelta(dis_filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps, weight_decay=opt.weight_decay)
    print("DisOptimizer:")
    print(dis_optimizer)
    ##---------------------------------------##

    """ final options """
    with open(os.path.join(opt.exp_dir,opt.exp_name,'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    """ start training """
    start_iter = 0
    
    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    
    #get schedulers
    scheduler = get_scheduler(optimizer,opt)
    ocr_scheduler = get_scheduler(ocr_optimizer,opt)
    dis_scheduler = get_scheduler(dis_optimizer,opt)

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    best_accuracy_ocr = -1
    best_norm_ED_ocr = -1
    iteration = start_iter
    cntr=0


    while(True):
        # train part
        
        if opt.lr_policy !="None":
            scheduler.step()
            ocr_scheduler.step()
            dis_scheduler.step()

        image_input_tensors, image_gt_tensors, labels_1, labels_2 = iter(train_loader).next()
        
        cntr+=1

        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        
        ##-----------------------------------##
        #generate text(labels) from ocr.forward
        if opt.ocrFixed:
            # ocrModel.eval()
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)
            
            if 'CTC' in opt.Prediction:
                preds = ocrModel(image_input_tensors, text_for_pred)
                preds = preds[:, :text_for_loss.shape[1] - 1, :]
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                labels_1 = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = ocrModel(image_input_tensors, text_for_pred, is_train=False)
                _, preds_index = preds.max(2)
                labels_1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(labels_1):
                    pred_EOS = pred.find('[s]')
                    labels_1[idx] = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
            # ocrModel.train()
        
        ##-----------------------------------##
        text_1, length_1 = converter.encode(labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)
        
        #forward pass from style and word generator
        images_recon_1, images_recon_2, style = model(image_input_tensors, text_1, text_2)


        if 'CTC' in opt.Prediction:
            
            if not opt.ocrFixed:
                #ocr training with orig image
                preds_ocr = ocrModel(image_input_tensors, text_1)
                preds_size_ocr = torch.IntTensor([preds_ocr.size(1)] * batch_size)
                preds_ocr = preds_ocr.log_softmax(2).permute(1, 0, 2)

                ocrCost_train = ocrCriterion(preds_ocr, text_1, preds_size_ocr, length_1)

            
            #content loss for reconstructed images
            preds_1 = ocrModel(images_recon_1, text_1)
            preds_size_1 = torch.IntTensor([preds_1.size(1)] * batch_size)
            preds_1 = preds_1.log_softmax(2).permute(1, 0, 2)

            preds_2 = ocrModel(images_recon_2, text_2)
            preds_size_2 = torch.IntTensor([preds_2.size(1)] * batch_size)
            preds_2 = preds_2.log_softmax(2).permute(1, 0, 2)
            ocrCost_1 = ocrCriterion(preds_1, text_1, preds_size_1, length_1)
            ocrCost_2 = ocrCriterion(preds_2, text_2, preds_size_2, length_2)
            # ocrCost = 0.5*( ocrCost_1 + ocrCost_2 )

        else:
            if not opt.ocrFixed:
                #ocr training with orig image
                preds_ocr = ocrModel(image_input_tensors, text_1[:, :-1])  # align with Attention.forward
                target_ocr = text_1[:, 1:]  # without [GO] Symbol

                ocrCost_train = ocrCriterion(preds_ocr.view(-1, preds_ocr.shape[-1]), target_ocr.contiguous().view(-1))

            #content loss for reconstructed images
            preds_1 = ocrModel(images_recon_1, text_1[:, :-1], is_train=False)  # align with Attention.forward
            target_1 = text_1[:, 1:]  # without [GO] Symbol

            preds_2 = ocrModel(images_recon_2, text_2[:, :-1], is_train=False)  # align with Attention.forward
            target_2 = text_2[:, 1:]  # without [GO] Symbol

            ocrCost_1 = ocrCriterion(preds_1.view(-1, preds_1.shape[-1]), target_1.contiguous().view(-1))
            ocrCost_2 = ocrCriterion(preds_2.view(-1, preds_2.shape[-1]), target_2.contiguous().view(-1))
            # ocrCost = 0.5*(ocrCost_1+ocrCost_2)
        
        if not opt.ocrFixed:
            #training OCR
            ocrModel.zero_grad()
            ocrCost_train.backward()
            # torch.nn.utils.clip_grad_norm_(ocrModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
            ocr_optimizer.step()
            #if ocr is fixed; ignore this loss
            loss_avg_ocr.add(ocrCost_train)
        else:
            loss_avg_ocr.add(torch.tensor(0.0))

        
        #Domain discriminator: Dis update
        disModel.zero_grad()
        disCost = opt.disWeight*0.5*(disModel.module.calc_dis_loss(images_recon_1.detach(), image_input_tensors) + disModel.module.calc_dis_loss(images_recon_2.detach(), image_gt_tensors))
        disCost.backward()
        # torch.nn.utils.clip_grad_norm_(disModel.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        dis_optimizer.step()
        loss_avg_dis.add(disCost)
        
        # #[Style Encoder] + [Word Generator] update
        #Adversarial loss
        disGenCost = 0.5*(disModel.module.calc_gen_loss(images_recon_1)+disModel.module.calc_gen_loss(images_recon_2))

        #Input reconstruction loss
        recCost = 0.5*(recCriterion(images_recon_1,image_input_tensors) + recCriterion(images_recon_2,image_gt_tensors))

        #Pair style reconstruction loss
        if opt.styleReconWeight == 0.0:
            styleRecCost = torch.tensor(0.0)
        else:
            # if opt.styleDetach:
            #     styleRecCost = styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style.detach())
            # else:
            #     styleRecCost = styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style)
            styleRecCost = 0.33*(styleRecCriterion(model(image_gt_tensors, None, None, styleFlag=True), style) + \
                styleRecCriterion(model(images_recon_1, None, None, styleFlag=True), style) + \
                    styleRecCriterion(model(images_recon_2, None, None, styleFlag=True), style))

        #OCR Content cost
        ocrCost = 0.5*(ocrCost_1+ocrCost_2)

        cost = opt.ocrWeight*ocrCost + opt.reconWeight*recCost + opt.disWeight*disGenCost + opt.styleReconWeight*styleRecCost

        model.zero_grad()
        ocrModel.zero_grad()
        disModel.zero_grad()
        cost.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        loss_avg.add(cost)

        #Individual losses
        loss_avg_ocrRecon_1.add(opt.ocrWeight*0.5*ocrCost_1)
        loss_avg_ocrRecon_2.add(opt.ocrWeight*0.5*ocrCost_2)
        loss_avg_gen.add(opt.disWeight*disGenCost)
        loss_avg_imgRecon.add(opt.reconWeight*recCost)
        loss_avg_styRecon.add(opt.styleReconWeight*styleRecCost)

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' 
            
            #Save training images
            os.makedirs(os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration)), exist_ok=True)
            for trImgCntr in range(batch_size):
                try:
                    save_image(tensor2im(image_input_tensors[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_sInput_'+labels_1[trImgCntr]+'.png'))
                    save_image(tensor2im(image_gt_tensors[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_csInput_'+labels_2[trImgCntr]+'.png'))
                    save_image(tensor2im(images_recon_1[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_sRecon_'+labels_1[trImgCntr]+'.png'))
                    save_image(tensor2im(images_recon_2[trImgCntr].detach()),os.path.join(opt.exp_dir,opt.exp_name,'trainImages',str(iteration),str(trImgCntr)+'_csRecon_'+labels_2[trImgCntr]+'.png'))
                except:
                    print('Warning while saving training image')
            
            elapsed_time = time.time() - start_time
            # for log
            
            with open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log:
                model.eval()
                ocrModel.module.Transformation.eval()
                ocrModel.module.FeatureExtraction.eval()
                ocrModel.module.AdaptiveAvgPool.eval()
                ocrModel.module.SequenceModeling.eval()
                ocrModel.module.Prediction.eval()
                disModel.eval()
                
                with torch.no_grad():                    
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation_synth_v2(
                        iteration, model, ocrModel, disModel, recCriterion, styleRecCriterion, ocrCriterion, valid_loader, converter, opt)
                model.train()
                if not opt.ocrFixed:
                    ocrModel.train()
                else:
                #     ocrModel.module.Transformation.eval()
                #     ocrModel.module.FeatureExtraction.eval()
                #     ocrModel.module.AdaptiveAvgPool.eval()
                    ocrModel.module.SequenceModeling.train()
                #     ocrModel.module.Prediction.eval()

                disModel.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train OCR loss: {loss_avg_ocr.val():0.5f}, Train Synth loss: {loss_avg.val():0.5f}, Train Dis loss: {loss_avg_dis.val():0.5f}, Valid OCR loss: {valid_loss[0]:0.5f}, Valid Synth loss: {valid_loss[1]:0.5f}, Valid Dis loss: {valid_loss[2]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                

                current_model_log_ocr = f'{"Current_accuracy_OCR":17s}: {current_accuracy[0]:0.3f}, {"Current_norm_ED_OCR":17s}: {current_norm_ED[0]:0.2f}'
                current_model_log_1 = f'{"Current_accuracy_recon":17s}: {current_accuracy[1]:0.3f}, {"Current_norm_ED_recon":17s}: {current_norm_ED[1]:0.2f}'
                current_model_log_2 = f'{"Current_accuracy_pair":17s}: {current_accuracy[2]:0.3f}, {"Current_norm_ED_pair":17s}: {current_norm_ED[2]:0.2f}'
                
                #plotting
                lib.plot.plot(os.path.join(plotDir,'Train-OCR-Loss'), loss_avg_ocr.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-Synth-Loss'), loss_avg.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item())
                
                lib.plot.plot(os.path.join(plotDir,'Train-OCR-Recon1-Loss'), loss_avg_ocrRecon_1.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-OCR-Recon2-Loss'), loss_avg_ocrRecon_2.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-ImgRecon1-Loss'), loss_avg_imgRecon.val().item())
                lib.plot.plot(os.path.join(plotDir,'Train-StyRecon2-Loss'), loss_avg_styRecon.val().item())

                lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Loss'), valid_loss[0].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-Synth-Loss'), valid_loss[1].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-Dis-Loss'), valid_loss[2].item())

                lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Recon1-Loss'), valid_loss[3].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-OCR-Recon2-Loss'), valid_loss[4].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-Gen-Loss'), valid_loss[5].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-ImgRecon1-Loss'), valid_loss[6].item())
                lib.plot.plot(os.path.join(plotDir,'Valid-StyRecon2-Loss'), valid_loss[7].item())

                lib.plot.plot(os.path.join(plotDir,'Orig-OCR-WordAccuracy'), current_accuracy[0])
                lib.plot.plot(os.path.join(plotDir,'Recon-OCR-WordAccuracy'), current_accuracy[1])
                lib.plot.plot(os.path.join(plotDir,'Pair-OCR-WordAccuracy'), current_accuracy[2])

                lib.plot.plot(os.path.join(plotDir,'Orig-OCR-CharAccuracy'), current_norm_ED[0])
                lib.plot.plot(os.path.join(plotDir,'Recon-OCR-CharAccuracy'), current_norm_ED[1])
                lib.plot.plot(os.path.join(plotDir,'Pair-OCR-CharAccuracy'), current_norm_ED[2])
                

                # keep best accuracy model (on valid dataset)
                if current_accuracy[1] > best_accuracy:
                    best_accuracy = current_accuracy[1]
                    torch.save(model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy.pth'))
                    torch.save(disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy_dis.pth'))
                if current_norm_ED[1] > best_norm_ED:
                    best_norm_ED = current_norm_ED[1]
                    torch.save(model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED.pth'))
                    torch.save(disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED_dis.pth'))
                best_model_log = f'{"Best_accuracy_Recon":17s}: {best_accuracy:0.3f}, {"Best_norm_ED_Recon":17s}: {best_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy[0] > best_accuracy_ocr:
                    best_accuracy_ocr = current_accuracy[0]
                    if not opt.ocrFixed:
                        torch.save(ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_accuracy_ocr.pth'))
                if current_norm_ED[0] > best_norm_ED_ocr:
                    best_norm_ED_ocr = current_norm_ED[0]
                    if not opt.ocrFixed:
                        torch.save(ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'best_norm_ED_ocr.pth'))
                best_model_log_ocr = f'{"Best_accuracy_ocr":17s}: {best_accuracy_ocr:0.3f}, {"Best_norm_ED_ocr":17s}: {best_norm_ED_ocr:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log_ocr}\n{current_model_log_1}\n{current_model_log_2}\n{best_model_log_ocr}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":32s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                
                for gt_ocr, pred_ocr, confidence_ocr, gt_1, pred_1, confidence_1, gt_2, pred_2, confidence_2 in zip(labels[0][:5], preds[0][:5], confidence_score[0][:5], labels[1][:5], preds[1][:5], confidence_score[1][:5], labels[2][:5], preds[2][:5], confidence_score[2][:5]):
                    if 'Attn' in opt.Prediction:
                        # gt_ocr = gt_ocr[:gt_ocr.find('[s]')]
                        pred_ocr = pred_ocr[:pred_ocr.find('[s]')]

                        # gt_1 = gt_1[:gt_1.find('[s]')]
                        pred_1 = pred_1[:pred_1.find('[s]')]

                        # gt_2 = gt_2[:gt_2.find('[s]')]
                        pred_2 = pred_2[:pred_2.find('[s]')]

                    predicted_result_log += f'{"ocr"}: {gt_ocr:27s} | {pred_ocr:25s} | {confidence_ocr:0.4f}\t{str(pred_ocr == gt_ocr)}\n'
                    predicted_result_log += f'{"recon"}: {gt_1:25s} | {pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n'
                    predicted_result_log += f'{"pair"}: {gt_2:26s} | {pred_2:25s} | {confidence_2:0.4f}\t{str(pred_2 == gt_2)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

                loss_avg_ocr.reset()
                loss_avg.reset()
                loss_avg_dis.reset()

                loss_avg_ocrRecon_1.reset()
                loss_avg_ocrRecon_2.reset()
                loss_avg_gen.reset()
                loss_avg_imgRecon.reset()
                loss_avg_styRecon.reset()

            lib.plot.flush()

        lib.plot.tick()

        # save model per 1e+5 iter.
        if (iteration) % 1e+5 == 0:
            torch.save(
                model.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth'))
            if not opt.ocrFixed:
                torch.save(
                    ocrModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_ocr.pth'))
            torch.save(
                disModel.state_dict(), os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_dis.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
Esempio n. 24
0
def validation(model, criterion, eval_loader, converter, opt, tqdm_position=1):
    """Validation or evaluation pass over ``eval_loader``.

    Args:
        model: recognition network. When ``"CTC"`` is in ``opt.Prediction`` it is
            called as ``model(image)``; otherwise as
            ``model(image, text_for_pred, is_train=False)``.
        criterion: loss matching the prediction head (called with CTC-style
            arguments for CTC, with flattened logits/targets otherwise).
        eval_loader: iterable of ``(image_tensors, labels)`` batches; must
            support ``len()``.
        converter: label converter with ``encode``/``decode``; for the
            attention head its ``dict`` must contain ``"[SOS]"``.
        opt: options; reads ``batch_max_length``, ``Prediction`` and ``NED``.
        tqdm_position: row used for the tqdm progress bar.

    Returns:
        ``(valid_loss, score, preds_str, confidence_score_list, labels,
        infer_time, length_of_data)``. ``score`` is the ICDAR2019 normalized
        edit distance in percent when ``opt.NED`` is set, otherwise word
        accuracy in percent. ``preds_str``, ``confidence_score_list`` and
        ``labels`` describe the LAST batch only.
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()
    # Evaluate a case-sensitive model in the alphanumeric, case-insensitive
    # setting (same protocol as ASTER). Compiled once instead of per sample.
    alphanumeric_case_insensitve = "0123456789abcdefghijklmnopqrstuvwxyz"
    out_of_alphanumeric_case_insensitve = re.compile(
        f"[^{alphanumeric_case_insensitve}]")

    for image_tensors, labels in tqdm(
            eval_loader,
            total=len(eval_loader),
            position=tqdm_position,
            leave=False,
    ):
        batch_size = image_tensors.size(0)
        length_of_data = length_of_data + batch_size
        image = image_tensors.to(device)
        # Encode labels (up to batch_max_length) for use as loss targets.
        labels_index, labels_length = converter.encode(
            labels, batch_max_length=opt.batch_max_length)

        if "CTC" in opt.Prediction:
            start_time = time.time()
            preds = model(image)
            forward_time = time.time() - start_time

            # Evaluation loss for the CTC decoder: permute to (T, N, C)
            # log-probabilities as expected by the CTC loss.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            cost = criterion(
                preds.log_softmax(2).permute(1, 0, 2),
                labels_index,
                preds_size,
                labels_length,
            )

        else:
            # Attention decoding starts from one [SOS] token per sample.
            text_for_pred = (torch.LongTensor(batch_size).fill_(
                converter.dict["[SOS]"]).to(device))

            start_time = time.time()
            preds = model(image, text_for_pred, is_train=False)
            forward_time = time.time() - start_time

            target = labels_index[:, 1:]  # without [SOS] Symbol
            cost = criterion(
                preds.contiguous().view(-1, preds.shape[-1]),
                target.contiguous().view(-1),
            )

        # Greedy decoding: most probable class per step, then index -> text.
        _, preds_index = preds.max(2)
        preds_size = torch.IntTensor([preds.size(1)] *
                                     preds_index.size(0)).to(device)
        preds_str = converter.decode(preds_index, preds_size)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # Accuracy / NED and per-sample confidence score.
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        confidence_score_list = []
        for gt, prd, prd_max_prob in zip(labels, preds_str, preds_max_prob):
            if "Attn" in opt.Prediction:
                # Prune after the "end of sentence" token ([EOS]). str.find
                # returns -1 when no [EOS] was decoded; slicing with -1 would
                # silently drop the last character, so keep everything then.
                prd_EOS = prd.find("[EOS]")
                if prd_EOS != -1:
                    prd = prd[:prd_EOS]
                    prd_max_prob = prd_max_prob[:prd_EOS]

            # In our experiment, if the model predicts at least one [UNK]
            # token, the word prediction counts as incorrect. To ignore [UNK]
            # tokens instead, use: prd = prd.replace('[UNK]', '')

            gt = out_of_alphanumeric_case_insensitve.sub("", gt.lower())
            prd = out_of_alphanumeric_case_insensitve.sub("", prd.lower())

            if opt.NED:
                # ICDAR2019 Normalized Edit Distance
                if len(gt) == 0 or len(prd) == 0:
                    norm_ED += 0
                elif len(gt) > len(prd):
                    norm_ED += 1 - edit_distance(prd, gt) / len(gt)
                else:
                    norm_ED += 1 - edit_distance(prd, gt) / len(prd)
            elif prd == gt:
                n_correct += 1

            # Confidence score = product of the per-step max probabilities;
            # 0 for an empty prediction (e.g. [EOS] decoded at step 0).
            if prd_max_prob.numel() > 0:
                confidence_score = prd_max_prob.cumprod(dim=0)[-1]
            else:
                confidence_score = 0
            confidence_score_list.append(confidence_score)

    if opt.NED:
        # ICDAR2019 Normalized Edit Distance. In web page, they report % of norm_ED (= norm_ED * 100).
        score = norm_ED / float(length_of_data) * 100
    else:
        score = n_correct / float(length_of_data) * 100  # accuracy

    return (
        valid_loss_avg.val(),
        score,
        preds_str,
        confidence_score_list,
        labels,
        infer_time,
        length_of_data,
    )
Esempio n. 25
0
def validation(model, criterion, evaluation_loader, converter, opt):
    """Validation or evaluation pass for an attention ('[s]'-terminated) model.

    Args:
        model: called as ``model(image, text_for_pred, is_train=False)``.
        criterion: loss over flattened logits and targets.
        evaluation_loader: iterable of ``(image_tensors, labels)`` batches.
        converter: label converter with ``encode``/``decode``.
        opt: options; reads ``batch_max_length``, ``sensitive`` and
            ``data_filtering_off``.

    Returns:
        ``(valid_loss, accuracy, norm_ED, preds_str, confidence_score_list,
        labels, infer_time, length_of_data)``. ``accuracy`` is in percent and
        ``norm_ED`` is the mean ICDAR2019 normalized edit distance (0..1).
        ``preds_str``, ``confidence_score_list`` and ``labels`` describe the
        LAST batch only.
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()
    # Used to evaluate a case-sensitive model in the alphanumeric,
    # case-insensitive setting; compiled once instead of per sample.
    alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz'
    out_of_alphanumeric_case_insensitve = re.compile(
        f'[^{alphanumeric_case_insensitve}]')

    for image_tensors, labels in evaluation_loader:
        batch_size = image_tensors.size(0)
        length_of_data = length_of_data + batch_size
        image = image_tensors.to(device)
        # For max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                          batch_size).to(device)
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                         1).fill_(0).to(device)

        text_for_loss, length_for_loss = converter.encode(
            labels, batch_max_length=opt.batch_max_length)

        start_time = time.time()
        preds = model(image, text_for_pred, is_train=False)
        forward_time = time.time() - start_time

        # Align prediction length with the target (targets drop the [GO] token).
        preds = preds[:, :text_for_loss.shape[1] - 1, :]
        target = text_for_loss[:, 1:]  # without [GO] Symbol
        cost = criterion(preds.contiguous().view(-1, preds.shape[-1]),
                         target.contiguous().view(-1))

        # select max probabilty (greedy decoding) then decode index to character
        _, preds_index = preds.max(2)
        preds_str = converter.decode(preds_index, length_for_pred)
        labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # calculate accuracy & confidence score
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        confidence_score_list = []
        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            # Prune after the "end of sentence" token ([s]). str.find returns
            # -1 when the token is absent; slicing with -1 would silently drop
            # the last character, so only truncate when the token was found.
            gt_EOS = gt.find('[s]')
            if gt_EOS != -1:
                gt = gt[:gt_EOS]
            pred_EOS = pred.find('[s]')
            if pred_EOS != -1:
                pred = pred[:pred_EOS]
                pred_max_prob = pred_max_prob[:pred_EOS]

            # To evaluate 'case sensitive model' with alphanumeric and case insensitve setting.
            if opt.sensitive and opt.data_filtering_off:
                pred = out_of_alphanumeric_case_insensitve.sub('', pred.lower())
                gt = out_of_alphanumeric_case_insensitve.sub('', gt.lower())

            if pred == gt:
                n_correct += 1

            # (old version) ICDAR2017 DOST Normalized Edit Distance
            # https://rrc.cvc.uab.es/?ch=7&com=tasks normalized the edit
            # distance by the ground-truth length only. Below is the
            # ICDAR2019 Normalized Edit Distance.
            if len(gt) == 0 or len(pred) == 0:
                norm_ED += 0
            elif len(gt) > len(pred):
                norm_ED += 1 - edit_distance(pred, gt) / len(gt)
            else:
                norm_ED += 1 - edit_distance(pred, gt) / len(pred)

            # Confidence score = product of the per-step max probabilities;
            # 0 for an empty prediction (e.g. [s] decoded at step 0).
            if pred_max_prob.numel() > 0:
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            else:
                confidence_score = 0
            confidence_score_list.append(confidence_score)

    accuracy = n_correct / float(length_of_data) * 100
    norm_ED = norm_ED / float(
        length_of_data)  # ICDAR2019 Normalized Edit Distance

    return valid_loss_avg.val(
    ), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data
def train(opt):
    """Jointly train the CTC ("major") and attention ("guide") paths of ``Model``.

    Parameters are split by the second component of their name. The guide
    path (Transformation, FeatureExtraction, SequenceModeling_Attn, Attention)
    is optimized with the attention cross-entropy loss. The major path
    (SequenceModeling_CTC, CTC) is optimized with the CTC loss. Each path has
    its own optimizer and linear warm-up schedule. When
    ``opt.continue_training`` is set, only the major path is optimized and the
    weights are loaded from ``opt.saved_model``.

    Relies on module-level ``device``, ``writer`` (an object with
    ``add_scalar``) and ``fol_ckpt`` (checkpoint directory) defined elsewhere
    in the file. Calls ``sys.exit()`` once ``opt.num_iter`` iterations are done.
    """
    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # Context manager so the dataset log is closed even if loading fails.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=
            True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')
    """ model configuration """
    if opt.baiduCTC:
        CTC_converter = CTCLabelConverterForBaiduWarpctc(opt.character)
    else:
        CTC_converter = CTCLabelConverter(opt.character)
    Attn_converter = AttnLabelConverter(opt.character)
    opt.num_class_ctc = len(CTC_converter.character)
    opt.num_class_attn = len(Attn_converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class_ctc, opt.num_class_attn, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling,
          opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # kaiming_normal_ rejects tensors with < 2 dims (e.g. 1-D norm
            # weights); those weights are set to 1 instead.
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    print("Model:")
    print(model)
    """ setup loss """
    if opt.baiduCTC:
        # need to install warpctc. see our guideline.
        if opt.label_smooth:
            criterion_major_path = SmoothCTCLoss(num_classes=opt.num_class_ctc,
                                                 weight=0.05)
        else:
            criterion_major_path = CTCLoss()
    else:
        criterion_major_path = torch.nn.CTCLoss(zero_infinity=True).to(device)
    # ignore [GO] token = ignore index 0
    criterion_guide_path = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # Loss averagers over 100-iteration windows (written to ``writer``) ...
    loss_avg_major_path = Averager()
    loss_avg_guide_path = Averager()
    # ... and over each validation interval (printed with the validation
    # log). They are kept separate so the 100-iteration reset does not empty
    # the values reported at validation time.
    loss_log_major_path = Averager()
    loss_log_guide_path = Averager()

    # Split trainable parameters between the guide and major paths.
    guide_parameters = []
    major_parameters = []
    guide_model_part_names = [
        "Transformation", "FeatureExtraction", "SequenceModeling_Attn",
        "Attention"
    ]
    major_model_part_names = ["SequenceModeling_CTC", "CTC"]
    for name, param in model.named_parameters():
        if param.requires_grad:
            if name.split(".")[1] in guide_model_part_names:
                guide_parameters.append(param)
            elif name.split(".")[1] in major_model_part_names:
                major_parameters.append(param)

    # When continuing training only the major (CTC) path is optimized.
    train_guide = not opt.continue_training
    if not train_guide:
        guide_parameters = []

    # setup optimizer
    optimizer_attn = None
    if opt.adam:
        optimizer_ctc = optim.Adam(major_parameters,
                                   lr=opt.lr,
                                   betas=(opt.beta1, 0.999))
        if train_guide:
            optimizer_attn = optim.Adam(guide_parameters,
                                        lr=opt.lr,
                                        betas=(opt.beta1, 0.999))
    else:
        optimizer_ctc = AdamW(major_parameters, lr=opt.lr)
        if train_guide:
            optimizer_attn = AdamW(guide_parameters, lr=opt.lr)

    # First construction (last_epoch=-1) records 'initial_lr' in each param
    # group; PyTorch LR schedulers require it when rebuilt below with
    # last_epoch != -1.
    scheduler_ctc = get_linear_schedule_with_warmup(
        optimizer_ctc, num_warmup_steps=10000, num_training_steps=opt.num_iter)
    scheduler_attn = None
    if train_guide:
        scheduler_attn = get_linear_schedule_with_warmup(
            optimizer_attn,
            num_warmup_steps=10000,
            num_training_steps=opt.num_iter)

    start_iter = 0
    if opt.saved_model != '' and train_guide:
        print(f'loading pretrained model from {opt.saved_model}')
        checkpoint = torch.load(opt.saved_model)
        start_iter = checkpoint['start_iter'] + 1
        # Optimizer/scheduler state is only restored for the AdamW setup;
        # with opt.adam the optimizers start from fresh state.
        if not opt.adam:
            optimizer_ctc.load_state_dict(
                checkpoint['optimizer_ctc_state_dict'])
            optimizer_attn.load_state_dict(
                checkpoint['optimizer_attn_state_dict'])
            scheduler_ctc.load_state_dict(
                checkpoint['scheduler_ctc_state_dict'])
            scheduler_attn.load_state_dict(
                checkpoint['scheduler_attn_state_dict'])
            print(scheduler_ctc.get_last_lr())
            print(scheduler_attn.get_last_lr())
        if opt.FT:
            model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        else:
            model.load_state_dict(checkpoint['model_state_dict'])
    if opt.continue_training:
        model.load_state_dict(torch.load(opt.saved_model))

    # Rebuild the schedulers so the learning rate resumes at ``start_iter``.
    scheduler_ctc = get_linear_schedule_with_warmup(
        optimizer_ctc,
        num_warmup_steps=10000,
        num_training_steps=opt.num_iter,
        last_epoch=start_iter - 1)
    if train_guide:
        scheduler_attn = get_linear_schedule_with_warmup(
            optimizer_attn,
            num_warmup_steps=10000,
            num_training_steps=opt.num_iter,
            last_epoch=start_iter - 1)
    """ final options """
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        for k, v in vars(opt).items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter - 1
    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        iteration += 1
        image = image_tensors.to(device)
        text_attn, _ = Attn_converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        text_ctc, length_ctc = CTC_converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)
        # align with Attention.forward: decoder input excludes the last token.
        preds_major, preds_guide = model(image, text_attn[:, :-1])
        preds_size = torch.IntTensor([preds_major.size(1)] * batch_size)
        if opt.baiduCTC:
            preds_major = preds_major.permute(1, 0, 2)  # to use CTCLoss format
            if opt.label_smooth:
                cost_ctc = criterion_major_path(preds_major, text_ctc,
                                                preds_size, length_ctc,
                                                batch_size)
            else:
                cost_ctc = criterion_major_path(
                    preds_major, text_ctc, preds_size, length_ctc) / batch_size
        else:
            preds_major = preds_major.log_softmax(2).permute(1, 0, 2)
            cost_ctc = criterion_major_path(preds_major, text_ctc, preds_size,
                                            length_ctc)
        target = text_attn[:, 1:]  # without [GO] Symbol
        if train_guide:
            cost_attn = criterion_guide_path(
                preds_guide.view(-1, preds_guide.shape[-1]),
                target.contiguous().view(-1))
            optimizer_attn.zero_grad()
            cost_attn.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(
                guide_parameters,
                opt.grad_clip)  # gradient clipping with 5 (Default)
            # NOTE(review): this updates guide-path weights in place before
            # cost_ctc.backward(); if the CTC graph depends on those weights,
            # recent PyTorch versions raise an in-place modification error.
            # Confirm the model detaches the shared features or that the
            # ordering is intended.
            optimizer_attn.step()
        optimizer_ctc.zero_grad()
        cost_ctc.backward()
        torch.nn.utils.clip_grad_norm_(
            major_parameters,
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer_ctc.step()
        scheduler_ctc.step()
        if train_guide:
            scheduler_attn.step()

        loss_avg_major_path.add(cost_ctc)
        loss_log_major_path.add(cost_ctc)
        if train_guide:
            loss_avg_guide_path.add(cost_attn)
            loss_log_guide_path.add(cost_attn)

        if (iteration + 1) % 100 == 0:
            summary_step = (iteration + 1) // 100
            writer.add_scalar("Loss/train_ctc", loss_avg_major_path.val(),
                              summary_step)
            loss_avg_major_path.reset()
            if train_guide:
                writer.add_scalar("Loss/train_attn", loss_avg_guide_path.val(),
                                  summary_step)
                loss_avg_guide_path.reset()

        # validation part
        if (iteration + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion_major_path, valid_loader,
                        CTC_converter, opt)
                model.train()
                val_step = (iteration + 1) // opt.valInterval
                writer.add_scalar("Loss/valid", valid_loss, val_step)
                writer.add_scalar("Metrics/accuracy", current_accuracy,
                                  val_step)
                writer.add_scalar("Metrics/norm_ED", current_norm_ED,
                                  val_step)

                # training loss (over this validation interval) and validation loss
                if train_guide:
                    loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss ctc: {loss_log_major_path.val():0.5f}, Train loss attn: {loss_log_guide_path.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                else:
                    loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss ctc: {loss_log_major_path.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_log_major_path.reset()
                if train_guide:
                    loss_log_guide_path.reset()
                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(),
                               f'{fol_ckpt}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(),
                               f'{fol_ckpt}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save a resumable checkpoint every 1000 iterations.
        if (iteration + 1) % 1000 == 0 and train_guide:
            torch.save(
                {
                    'model_state_dict': model.state_dict(),
                    'optimizer_attn_state_dict': optimizer_attn.state_dict(),
                    'optimizer_ctc_state_dict': optimizer_ctc.state_dict(),
                    'start_iter': iteration,
                    'scheduler_ctc_state_dict': scheduler_ctc.state_dict(),
                    'scheduler_attn_state_dict': scheduler_attn.state_dict(),
                }, f'{fol_ckpt}/current_model.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
Esempio n. 27
0
def _is_set(path):
    """Return True when a checkpoint option holds a path (i.e. is neither '' nor the string 'None')."""
    return path != '' and path != 'None'


def _synth_ckpt_iteration(path):
    """Return N parsed from a checkpoint file named ``iter_<N>_synth.pth``, or None.

    Only the base name is inspected, so underscores in directory names cannot be
    mistaken for the iteration number.
    """
    try:
        return int(os.path.basename(path).split('_')[-2].split('.')[0])
    except (IndexError, ValueError):
        return None


def _latest_synth_checkpoint(folder):
    """Return the ``iter_*_synth.pth`` file in ``folder`` with the highest iteration, or None.

    ``glob`` result order depends on the file system, so the last element is not
    necessarily the newest checkpoint; compare the parsed iteration numbers instead.
    """
    def iteration_of(path):
        parsed = _synth_ckpt_iteration(path)
        return -1 if parsed is None else parsed

    candidates = glob.glob(os.path.join(folder, 'iter_*_synth.pth'))
    return max(candidates, key=iteration_of) if candidates else None


def _sample_batches(loader):
    """Yield batches from ``loader`` forever, one full pass (epoch) after another.

    Keeping a single live iterator avoids re-creating the DataLoader iterator (and
    its worker processes) on every draw. When the sampler supports ``set_epoch``
    (e.g. a distributed sampler) it is advanced so each pass is reshuffled.

    Raises:
        RuntimeError: if a whole pass over ``loader`` yields no batch, which would
            otherwise loop forever.
    """
    sampler = getattr(loader, 'sampler', None)
    epoch = 0
    while True:
        if hasattr(sampler, 'set_epoch'):
            sampler.set_epoch(epoch)
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise RuntimeError('training DataLoader produced no batches '
                               '(check dataset size, batch_size and drop_last)')
        epoch += 1


def _prune_after_eos(preds):
    """Cut each decoded attention prediction at its first "[s]" token.

    Predictions without "[s]" are kept whole (slicing with ``find() == -1`` would
    silently drop their last character).
    """
    pruned = []
    for pred in preds:
        eos = pred.find('[s]')
        pruned.append(pred if eos == -1 else pred[:eos])
    return pruned


def _greedy_ocr_labels(preds, converter, use_ctc, length_for_pred):
    """Greedy-decode OCR model outputs into label strings.

    Args:
        preds: OCR output whose dim 0 is the batch, dim 1 the decoding steps and
            dim 2 the classes (reduced with ``max(2)``).
        converter: label converter providing ``decode``.
        use_ctc: decode with length ``preds.size(1)`` per sample (CTC) instead of
            ``length_for_pred`` followed by "[s]" pruning (attention decoder).
        length_for_pred: per-sample decode lengths for the attention decoder.
    """
    _, preds_index = preds.max(2)
    if use_ctc:
        preds_size = torch.IntTensor([preds.size(1)] * preds.size(0))
        return converter.decode(preds_index.data, preds_size.data)
    return _prune_after_eos(converter.decode(preds_index, length_for_pred))


def _save_training_samples(out_dir, style_inputs, gt_images, recon_images,
                           labels_1, labels_2, ocr_labels=None):
    """Save the style input, content/style ground truth and generated image of each sample as PNGs.

    File names are ``<idx>_sInput_<label_1>``, ``<idx>_csGT_<label_2>`` and
    ``<idx>_csRecon_<label_2>``; when ``ocr_labels`` is a ``(style_input, gt, recon)``
    tuple of per-sample OCR strings, ``_<ocr>`` is appended to the matching name.
    Saving is best-effort: a failure for one sample is reported and skipped.
    """
    os.makedirs(out_dir, exist_ok=True)
    for idx in range(style_inputs.size(0)):
        try:
            if ocr_labels is None:
                s_suffix = sc_suffix = o_suffix = ''
            else:
                s_suffix = '_' + ocr_labels[0][idx]
                sc_suffix = '_' + ocr_labels[1][idx]
                o_suffix = '_' + ocr_labels[2][idx]
            save_image(tensor2im(style_inputs[idx].detach()),
                       os.path.join(out_dir, str(idx) + '_sInput_' + labels_1[idx] + s_suffix + '.png'))
            save_image(tensor2im(gt_images[idx].detach()),
                       os.path.join(out_dir, str(idx) + '_csGT_' + labels_2[idx] + sc_suffix + '.png'))
            save_image(tensor2im(recon_images[idx].detach()),
                       os.path.join(out_dir, str(idx) + '_csRecon_' + labels_2[idx] + o_suffix + '.png'))
        except Exception as err:  # a label that is not a valid file name must not stop training
            print(f'Warning while saving training image: {err}')


def train(opt):
    """Train the ``styleGANGen`` word-image generator against ``styleGANDis`` with an OCR loss.

    Sets up paired train/valid loaders, the generator ``genModel``, its EMA copy
    ``g_ema``, the discriminator ``disModel`` and an OCR model ``ocrModel`` (kept in
    eval mode; its parameters are in no optimiser). Each iteration then performs:

    1. a discriminator step (``d_logistic_loss``; optional ``augment`` whose
       probability is adapted when ``opt.augment_p == 0``; R1 penalty every
       ``opt.d_reg_every`` iterations);
    2. a generator step (``opt.disWeight * g_nonsaturating_loss`` +
       ``opt.ocrWeight * ocrCost``);
    3. path-length regularisation every ``opt.g_reg_every`` iterations;
    4. an ``accumulate`` (EMA) update of ``g_ema``.

    Every ``opt.valInterval`` iterations (and at iteration 0) sample images are saved
    under ``opt.trainDir``, ``validation_synth_v6`` is run, and losses are printed,
    appended to ``log_train.txt`` and plotted. Checkpoint ``iter_<N>_synth.pth`` is
    written every 10k iterations. The function never returns: it calls
    ``sys.exit()`` after ``opt.num_iter`` iterations.

    Args:
        opt: options namespace holding all paths, model and optimisation settings.
    """
    lib.print_model_settings(locals().copy())

    # ------------------------------------------------------------ data
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a') as log:
        train_dataset, train_dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=opt.batch_size,
            sampler=data_sampler(train_dataset, shuffle=True, distributed=opt.distributed),
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
        log.write(train_dataset_log)
        print('-' * 80)

        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            # The sampler must index the validation set itself (it was built from train_dataset).
            sampler=data_sampler(valid_dataset, shuffle=False, distributed=opt.distributed),
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    use_ctc = 'CTC' in opt.Prediction

    def gen_text(text):
        # CTC label tensors are fed to the generator unchanged; otherwise the first
        # ([GO]) and last columns are dropped, matching every generator call here.
        return text if use_ctc else text[:, 1:-1]

    # ---------------------------------------------------------- models
    genModel = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class,
                           channel_multiplier=opt.channel_multiplier).to(device)
    disModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier,
                           input_dim=opt.input_channel).to(device)
    g_ema = styleGANGen(opt.size, opt.latent, opt.n_mlp, opt.num_class,
                        channel_multiplier=opt.channel_multiplier).to(device)
    ocrModel = ModelV1(opt).to(device)
    accumulate(g_ema, genModel, 0)  # decay 0: presumably copies genModel's weights into g_ema

    if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
        # Content loss compares OCR features (returnFeat) of generated vs. real images.
        ocrCriterion = torch.nn.L1Loss()
    elif use_ctc:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0

    print('model input parameters', opt.imgH, opt.imgW, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length)

    if opt.distributed:
        def wrap_ddp(model):
            return torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[opt.local_rank],
                output_device=opt.local_rank,
                broadcast_buffers=False,
            )

        genModel = wrap_ddp(genModel)
        disModel = wrap_ddp(disModel)
        ocrModel = wrap_ddp(ocrModel)

    genModel.train()
    g_ema.eval()
    disModel.train()
    ocrModel.eval()

    # Unwrapped modules: used for checkpoint I/O and the EMA update.
    g_module = genModel.module if opt.distributed else genModel
    d_module = disModel.module if opt.distributed else disModel

    # Regularisation runs only every *_reg_every steps, so lr and betas are
    # rescaled by N / (N + 1).
    g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1)
    d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1)
    optimizer = optim.Adam(
        genModel.parameters(),
        lr=opt.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    dis_optimizer = optim.Adam(
        disModel.parameters(),
        lr=opt.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )

    # ---------------------------------------------- checkpoint loading
    if opt.modelFolderFlag:
        latest = _latest_synth_checkpoint(os.path.join(opt.exp_dir, opt.exp_name))
        if latest is not None:
            opt.saved_synth_model = latest

    if _is_set(opt.saved_ocr_model):
        # NOTE(review): the temporary DataParallel wrap presumably matches
        # 'module.'-prefixed keys in the OCR checkpoint; confirm.
        if not opt.distributed:
            ocrModel = torch.nn.DataParallel(ocrModel)
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        ocrModel.load_state_dict(torch.load(opt.saved_ocr_model, map_location='cpu'))
        if not opt.distributed:
            ocrModel = ocrModel.module

    if _is_set(opt.saved_gen_model):
        print(f'loading pretrained gen model from {opt.saved_gen_model}')
        checkpoint = torch.load(opt.saved_gen_model, map_location='cpu')
        # g_ema is never wrapped and genModel has ``.module`` only under DDP, so
        # load through the unwrapped modules.
        g_module.load_state_dict(checkpoint['g'])
        g_ema.load_state_dict(checkpoint['g_ema'])

    if _is_set(opt.saved_synth_model):
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model, map_location='cpu')
        # Checkpoints are saved from the unwrapped modules; load them the same way.
        g_module.load_state_dict(checkpoint['genModel'])
        g_ema.load_state_dict(checkpoint['g_ema'])
        d_module.load_state_dict(checkpoint['disModel'])
        optimizer.load_state_dict(checkpoint["optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])

    # ------------------------------------------------------ averagers
    loss_avg = Averager()
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    loss_avg_ocr = Averager()
    log_r1_val = Averager()
    log_avg_path_loss_val = Averager()
    log_avg_mean_path_length_avg = Averager()
    log_ada_aug_p = Averager()

    # --------------------------------------------------- final options
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'), 'a') as opt_file:
        opt_log = ('------------ Options -------------\n'
                   + ''.join(f'{str(k)}: {str(v)}\n' for k, v in vars(opt).items())
                   + '---------------------------------------\n')
        print(opt_log)
        opt_file.write(opt_log)

    # -------------------------------------------------- start training
    start_iter = 0
    if _is_set(opt.saved_synth_model):
        parsed_iter = _synth_ckpt_iteration(opt.saved_synth_model)
        if parsed_iter is not None:
            start_iter = parsed_iter
            print(f'continue to train, start_iter: {start_iter}')

    scheduler = get_scheduler(optimizer, opt)
    dis_scheduler = get_scheduler(dis_optimizer, opt)

    start_time = time.time()
    iteration = start_iter
    cntr = 0  # iterations run by this process; drives the regularisation schedule

    mean_path_length = 0
    r1_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    accum = 0.5 ** (32 / (10 * 1000))  # decay passed to accumulate() for g_ema
    ada_augment = torch.tensor([0.0, 0.0], device=device)
    ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0
    ada_aug_step = opt.ada_target / opt.ada_length
    r_t_stat = 0

    train_batches = _sample_batches(train_loader)

    while True:
        # --------------------------------------- discriminator update
        _, d_real_imgs, _, d_labels = next(train_batches)
        d_real_imgs = d_real_imgs.to(device)
        d_batch_size = d_real_imgs.size(0)

        requires_grad(genModel, False)
        requires_grad(disModel, True)

        d_text, _ = converter.encode(d_labels, batch_max_length=opt.batch_max_length)
        d_style = mixing_noise(d_batch_size, opt.latent, opt.mixing, device)
        d_fake_imgs, _ = genModel(d_style, gen_text(d_text), input_is_latent=opt.input_latent)

        if opt.augment:
            d_real_imgs_aug, _ = augment(d_real_imgs, ada_aug_p)
            d_fake_imgs, _ = augment(d_fake_imgs, ada_aug_p)
        else:
            d_real_imgs_aug = d_real_imgs

        fake_pred = disModel(d_fake_imgs)
        real_pred = disModel(d_real_imgs_aug)
        disCost = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = disCost * opt.disWeight
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()
        loss_avg_dis.add(disCost)

        disModel.zero_grad()
        disCost.backward()
        dis_optimizer.step()

        if opt.augment and opt.augment_p == 0:
            # Adapt the augmentation probability from the sign of D's real predictions.
            ada_augment += torch.tensor(
                (torch.sign(real_pred).sum().item(), real_pred.shape[0]), device=device
            )
            ada_augment = reduce_sum(ada_augment)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()
                r_t_stat = pred_signs / n_pred
                sign = 1 if r_t_stat > opt.ada_target else -1
                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)

        if cntr % opt.d_reg_every == 0:
            # R1 penalty on the non-augmented real images.
            d_real_imgs.requires_grad = True
            real_pred = disModel(d_real_imgs)
            r1_loss = d_r1_loss(real_pred, d_real_imgs)
            disModel.zero_grad()
            (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward()
            dis_optimizer.step()

        loss_dict["r1"] = r1_loss

        # ------------------------------------------- generator update
        image_input_tensors, image_gt_tensors, labels_1, labels_2 = next(train_batches)
        image_input_tensors = image_input_tensors.to(device)
        image_gt_tensors = image_gt_tensors.to(device)
        batch_size = image_input_tensors.size(0)

        requires_grad(genModel, True)
        requires_grad(disModel, False)

        text_2, length_2 = converter.encode(labels_2, batch_max_length=opt.batch_max_length)
        style = mixing_noise(batch_size, opt.latent, opt.mixing, device)
        # NOTE(review): unlike the discriminator step and g_ema sampling, this call
        # does not pass input_is_latent=opt.input_latent; confirm that is intended.
        images_recon_2, _ = genModel(style, gen_text(text_2))

        if opt.augment:
            images_recon_2, _ = augment(images_recon_2, ada_aug_p)

        fake_pred = disModel(images_recon_2)
        disGenCost = g_nonsaturating_loss(fake_pred)
        loss_dict["g"] = disGenCost

        # OCR content loss; the OCR model is called with is_train=False and
        # all-zero placeholder text of maximum length.
        text_for_pred = torch.zeros(batch_size, opt.batch_max_length + 1, dtype=torch.long, device=device)
        length_for_pred = torch.full((batch_size,), opt.batch_max_length, dtype=torch.int32, device=device)
        if opt.contentLoss == 'vis' or opt.contentLoss == 'seq':
            preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False, returnFeat=opt.contentLoss)
            preds_gt = ocrModel(image_gt_tensors, text_for_pred, is_train=False, returnFeat=opt.contentLoss)
            ocrCost = ocrCriterion(preds_recon, preds_gt)
        elif use_ctc:
            preds_recon = ocrModel(images_recon_2, text_for_pred, is_train=False)
            preds_size = torch.IntTensor([preds_recon.size(1)] * batch_size)
            preds_recon_log_probs = preds_recon.log_softmax(2).permute(1, 0, 2)  # steps first for CTCLoss
            ocrCost = ocrCriterion(preds_recon_log_probs, text_2, preds_size, length_2)
        else:
            preds_recon = ocrModel(images_recon_2, text_for_pred[:, :-1], is_train=False)  # align with Attention.forward
            target_2 = text_2[:, 1:]  # without [GO] Symbol
            ocrCost = ocrCriterion(preds_recon.view(-1, preds_recon.shape[-1]), target_2.contiguous().view(-1))

        cost = opt.disWeight * disGenCost + opt.ocrWeight * ocrCost

        genModel.zero_grad()
        disModel.zero_grad()
        ocrModel.zero_grad()  # OCR parameters receive gradients but are never stepped
        cost.backward()
        optimizer.step()
        loss_avg.add(cost)

        # ------------------------------ path-length regularisation
        if cntr % opt.g_reg_every == 0:
            # reg_* names keep the generator-step batch intact for sample saving below.
            reg_inputs, _, _, reg_labels = next(train_batches)
            reg_text, _ = converter.encode(reg_labels, batch_max_length=opt.batch_max_length)
            path_batch_size = max(1, reg_inputs.size(0) // opt.path_batch_shrink)

            reg_style = mixing_noise(path_batch_size, opt.latent, opt.mixing, device)
            reg_imgs, latents = genModel(reg_style, gen_text(reg_text[:path_batch_size]), return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                reg_imgs, latents, mean_path_length
            )

            genModel.zero_grad()
            weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss
            if opt.path_batch_shrink:
                weighted_path_loss += 0 * reg_imgs[0, 0, 0, 0]
            weighted_path_loss.backward()
            optimizer.step()

            mean_path_length_avg = reduce_sum(mean_path_length).item() / get_world_size()

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Step LR schedules after the optimiser updates, as PyTorch >= 1.1 expects.
        if opt.lr_policy != "None":
            scheduler.step()
            dis_scheduler.step()

        accumulate(g_ema, g_module, accum)

        loss_reduced = reduce_loss_dict(loss_dict)
        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        loss_avg_gen.add(opt.disWeight * disGenCost)
        loss_avg_ocr.add(opt.ocrWeight * ocrCost)
        log_r1_val.add(loss_reduced["r1"])  # previously logged the path loss here
        log_avg_path_loss_val.add(loss_reduced["path"])
        log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg))
        log_ada_aug_p.add(torch.tensor(ada_aug_p))

        if get_rank() == 0 and wandb and opt.wandb:
            wandb.log(
                {
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                }
            )

        # -------------------------------------------------- validation
        # Also validate at iteration 0 to see the starting point of training.
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            with torch.no_grad():
                sample_imgs, _ = g_ema(style, gen_text(text_2), input_is_latent=opt.input_latent)
                ocr_labels = None
                if opt.contentLoss != 'vis' and opt.contentLoss != 'seq':
                    ocr_labels = (
                        _greedy_ocr_labels(ocrModel(image_input_tensors, text_for_pred, is_train=False),
                                           converter, use_ctc, length_for_pred),
                        _greedy_ocr_labels(ocrModel(image_gt_tensors, text_for_pred, is_train=False),
                                           converter, use_ctc, length_for_pred),
                        _greedy_ocr_labels(preds_recon, converter, use_ctc, length_for_pred),
                    )
            _save_training_samples(os.path.join(opt.trainDir, str(iteration)),
                                   image_input_tensors, image_gt_tensors, sample_imgs,
                                   labels_1, labels_2, ocr_labels)

            elapsed_time = time.time() - start_time
            with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'), 'a') as log:
                genModel.eval()
                g_ema.eval()
                disModel.eval()
                with torch.no_grad():
                    valid_loss, infer_time, length_of_data = validation_synth_v6(
                        iteration, g_ema, ocrModel, disModel, ocrCriterion, valid_loader, converter, opt)
                genModel.train()
                disModel.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train Synth loss: {loss_avg.val():0.5f}, \
                    Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\
                    Train OCR loss: {loss_avg_ocr.val():0.5f}, \
                    Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \
                    Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \
                    Valid Synth loss: {valid_loss[0]:0.5f}, \
                    Valid Dis loss: {valid_loss[1]:0.5f}, Valid Gen loss: {valid_loss[2]:0.5f}, \
                    Valid OCR loss: {valid_loss[6]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'

                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Synth-Loss'), loss_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Dis-Loss'), loss_avg_dis.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-Gen-Loss'), loss_avg_gen.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-OCR-Loss'), loss_avg_ocr.val().item())

                lib.plot.plot(os.path.join(opt.plotDir, 'Train-r1_val'), log_r1_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-path_loss_val'), log_avg_path_loss_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Train-ada_aug_p'), log_ada_aug_p.val().item())

                lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Synth-Loss'), valid_loss[0].item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Dis-Loss'), valid_loss[1].item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Valid-Gen-Loss'), valid_loss[2].item())
                lib.plot.plot(os.path.join(opt.plotDir, 'Valid-OCR-Loss'), valid_loss[6].item())

                print(loss_log)
                log.write(loss_log + '\n')  # the log file was opened but never written before

                for averager in (loss_avg, loss_avg_dis, loss_avg_gen, loss_avg_ocr,
                                 log_r1_val, log_avg_path_loss_val,
                                 log_avg_mean_path_length_avg, log_ada_aug_p):
                    averager.reset()

            lib.plot.flush()

        lib.plot.tick()

        # Save every 10k iterations (including iteration 0). The file name records
        # iteration + 1, which is the iteration a resumed run starts from. Only
        # rank 0 writes so distributed ranks do not race on the same file.
        if iteration % 10000 == 0 and get_rank() == 0:
            torch.save({
                'genModel': g_module.state_dict(),
                'g_ema': g_ema.state_dict(),
                'disModel': d_module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict()},
                os.path.join(opt.exp_dir, opt.exp_name, 'iter_' + str(iteration + 1) + '_synth.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
        cntr += 1
Esempio n. 28
0
def validation(model, criterion, evaluation_loader, converter, opt):
    """Evaluate a text recognizer over ``evaluation_loader``.

    Args:
        model: recognizer, called as ``model(image, text)`` for CTC heads and
            ``model(image, text, is_train=False)`` otherwise.
        criterion: loss matching the prediction head (CTC, or a per-token loss
            applied to flattened logits for attention heads).
        evaluation_loader: iterable of ``(image_tensors, labels)`` batches.
        converter: label converter exposing ``encode`` / ``decode``.
        opt: options namespace; reads ``batch_max_length``, ``Prediction``,
            ``baiduCTC``, ``sensitive`` and ``data_filtering_off``.

    Returns:
        ``(avg_loss, accuracy, norm_ED, preds_str, confidence_score_list,
        labels, infer_time, length_of_data)`` where ``accuracy`` is a
        percentage, ``norm_ED`` is the mean ICDAR2019 normalized edit
        similarity, and ``preds_str`` / ``confidence_score_list`` / ``labels``
        describe the LAST batch only. If the loader yields no samples,
        ``accuracy`` and ``norm_ED`` are 0 and those lists are empty.
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()

    # Defaults so the return statement is well-defined for an empty loader.
    preds_str = []
    labels = []
    confidence_score_list = []

    # Character filter used to score a case-sensitive model in the
    # case-insensitive setting. Built and compiled once rather than per sample.
    case_insensitive_filter = None
    if opt.sensitive and opt.data_filtering_off:
        alphanumeric_case_insensitve = ' 0123456789가각간갇갈감갑값갓강갖같갚갛개객걀걔거걱건걷걸검겁것겉게겨격겪견결겹경곁계고곡곤곧골곰곱곳공과관광괜괴굉교구국군굳굴굵굶굽궁권귀귓규균귤그극근글긁금급긋긍기긴길김깅깊까깍깎깐깔깜깝깡깥깨꺼꺾껌껍껏껑께껴꼬꼭꼴꼼꼽꽂꽃꽉꽤꾸꾼꿀꿈뀌끄끈끊끌끓끔끗끝끼낌나낙낚난날낡남납낫낭낮낯낱낳내냄냇냉냐냥너넉넌널넓넘넣네넥넷녀녁년념녕노녹논놀놈농높놓놔뇌뇨누눈눕뉘뉴늄느늑는늘늙능늦늬니닐님다닥닦단닫달닭닮담답닷당닿대댁댐댓더덕던덜덟덤덥덧덩덮데델도독돈돌돕돗동돼되된두둑둘둠둡둥뒤뒷드득든듣들듬듭듯등디딩딪따딱딴딸땀땅때땜떠떡떤떨떻떼또똑뚜뚫뚱뛰뜨뜩뜯뜰뜻띄라락란람랍랑랗래랜램랫략량러럭런럴럼럽럿렁렇레렉렌려력련렬렵령례로록론롬롭롯료루룩룹룻뤄류륙률륭르른름릇릎리릭린림립릿링마막만많말맑맘맙맛망맞맡맣매맥맨맵맺머먹먼멀멈멋멍멎메멘멩며면멸명몇모목몬몰몸몹못몽묘무묵묶문묻물뭄뭇뭐뭘뭣므미민믿밀밉밌및밑바박밖반받발밝밟밤밥방밭배백뱀뱃뱉버번벌범법벗베벤벨벼벽변별볍병볕보복볶본볼봄봇봉뵈뵙부북분불붉붐붓붕붙뷰브븐블비빌빔빗빚빛빠빡빨빵빼뺏뺨뻐뻔뻗뼈뼉뽑뿌뿐쁘쁨사삭산살삶삼삿상새색샌생샤서석섞선설섬섭섯성세섹센셈셋셔션소속손솔솜솟송솥쇄쇠쇼수숙순숟술숨숫숭숲쉬쉰쉽슈스슨슬슴습슷승시식신싣실싫심십싯싱싶싸싹싼쌀쌍쌓써썩썰썹쎄쏘쏟쑤쓰쓴쓸씀씌씨씩씬씹씻아악안앉않알앓암압앗앙앞애액앨야약얀얄얇양얕얗얘어억언얹얻얼엄업없엇엉엊엌엎에엔엘여역연열엷염엽엿영옆예옛오옥온올옮옳옷옹와완왕왜왠외왼요욕용우욱운울움웃웅워원월웨웬위윗유육율으윽은을음응의이익인일읽잃임입잇있잊잎자작잔잖잘잠잡잣장잦재쟁쟤저적전절젊점접젓정젖제젠젯져조족존졸좀좁종좋좌죄주죽준줄줌줍중쥐즈즉즌즐즘증지직진질짐집짓징짙짚짜짝짧째쨌쩌쩍쩐쩔쩜쪽쫓쭈쭉찌찍찢차착찬찮찰참찻창찾채책챔챙처척천철첩첫청체쳐초촉촌촛총촬최추축춘출춤춥춧충취츠측츰층치칙친칠침칫칭카칸칼캄캐캠커컨컬컴컵컷케켓켜코콘콜콤콩쾌쿄쿠퀴크큰클큼키킬타탁탄탈탑탓탕태택탤터턱턴털텅테텍텔템토톤톨톱통퇴투툴툼퉁튀튜트특튼튿틀틈티틱팀팅파팎판팔팝패팩팬퍼퍽페펜펴편펼평폐포폭폰표푸푹풀품풍퓨프플픔피픽필핏핑하학한할함합항해핵핸햄햇행향허헌험헤헬혀현혈협형혜호혹혼홀홈홉홍화확환활황회획횟횡효후훈훌훔훨휘휴흉흐흑흔흘흙흡흥흩희흰히힘.?'
        case_insensitive_filter = re.compile(f'[^{alphanumeric_case_insensitve}]')

    for image_tensors, labels in evaluation_loader:
        batch_size = image_tensors.size(0)
        length_of_data += batch_size
        image = image_tensors.to(device)
        # For max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                          batch_size).to(device)
        text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                         1).fill_(0).to(device)

        text_for_loss, length_for_loss = converter.encode(
            labels, batch_max_length=opt.batch_max_length)

        start_time = time.time()
        if 'CTC' in opt.Prediction:
            preds = model(image, text_for_pred)
            forward_time = time.time() - start_time

            # CTC evaluation loss; CTC losses take (T, N, C), hence the permute.
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                cost = criterion(preds.permute(1, 0, 2), text_for_loss,
                                 preds_size, length_for_loss) / batch_size
            else:
                cost = criterion(
                    preds.log_softmax(2).permute(1, 0, 2), text_for_loss,
                    preds_size, length_for_loss)

            # Greedy decoding: most probable class per time step.
            _, preds_index = preds.max(2)
            if opt.baiduCTC:
                preds_index = preds_index.view(-1)
            preds_str = converter.decode(preds_index.data, preds_size.data)

        else:
            preds = model(image, text_for_pred, is_train=False)
            forward_time = time.time() - start_time

            preds = preds[:, :text_for_loss.shape[1] - 1, :]
            target = text_for_loss[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.contiguous().view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

            # Greedy decoding, then decode index to character.
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

        infer_time += forward_time
        valid_loss_avg.add(cost)

        # Accuracy and confidence score for this batch.
        preds_prob = F.softmax(preds, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)
        confidence_score_list = []
        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            if 'Attn' in opt.Prediction:
                gt = gt[:gt.find('[s]')]
                pred_EOS = pred.find('[s]')
                pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                pred_max_prob = pred_max_prob[:pred_EOS]

            # Evaluate a case-sensitive model in the case-insensitive setting.
            if case_insensitive_filter is not None:
                pred = case_insensitive_filter.sub('', pred.lower())
                gt = case_insensitive_filter.sub('', gt.lower())

            if pred == gt:
                n_correct += 1

            # ICDAR2019 normalized edit distance: 1 - ED / max(len(gt), len(pred)),
            # contributing 0 when either string is empty.
            # (The older ICDAR2017 DOST variant, https://rrc.cvc.uab.es/?ch=7&com=tasks,
            # normalized by len(gt) only and counted an empty gt as 1.)
            if gt and pred:
                norm_ED += 1 - edit_distance(pred, gt) / max(len(gt), len(pred))

            # Confidence = product of per-step max probabilities; 0 for an empty
            # prediction (e.g. when [s] is the first decoded token).
            if pred_max_prob.numel() > 0:
                confidence_score = pred_max_prob.cumprod(dim=0)[-1]
            else:
                confidence_score = 0
            confidence_score_list.append(confidence_score)

    if length_of_data > 0:
        accuracy = n_correct / float(length_of_data) * 100
        norm_ED = norm_ED / float(length_of_data)  # ICDAR2019 Normalized Edit Distance
    else:
        accuracy = 0.0
        norm_ED = 0.0

    return valid_loss_avg.val(
    ), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data
def _infinite_batches(loader):
    """Yield batches from ``loader`` forever, starting a new pass when one ends.

    Keeping a single live iterator avoids re-creating the DataLoader iterator
    (and re-spawning its worker processes) on every training step.

    Raises:
        ValueError: if a complete pass over ``loader`` yields no batch, which
            would otherwise loop forever (e.g. a dataset smaller than the batch
            size with ``drop_last=True``).
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError('DataLoader yielded no batches; check dataset size against batch_size/drop_last')


def _checkpoint_iteration(path):
    """Return N from a checkpoint file named ``iter_<N>_synth.pth``, or -1 if it does not parse."""
    try:
        return int(os.path.basename(path).split('_')[1])
    except (IndexError, ValueError):
        return -1


def _compose_style(noise_style, z_c_code, opt):
    """Modulate the (at most two, for style mixing) noise vectors by the content code.

    With ``opt.zAlone`` each style is truncated to its first ``opt.latent``
    columns, used to validate plain StyleGAN results.
    """
    style = [noise * z_c_code for noise in noise_style[:2]]
    if opt.zAlone:
        style = [s[:, :opt.latent] for s in style]
    return style


def train(opt):
    """Jointly train the content encoder, StyleGAN generator/discriminator and OCR model.

    Each iteration:
      1. updates the discriminator, with R1 regularisation every ``opt.d_reg_every`` steps;
      2. updates the OCR model on real images unless ``opt.ocrFixed`` or ``opt.zAlone``;
      3. updates generator + content encoder with the non-saturating adversarial loss
         plus ``opt.ocrWeight`` times the OCR loss on generated images
         (optionally gradient-balanced with ``opt.grad_balance``);
      4. applies path-length regularisation every ``opt.g_reg_every`` steps;
      5. updates the EMA generator ``g_ema``.
    Every ``opt.valInterval`` iterations (and at iteration 0) sample images are written
    to ``opt.trainDir`` and losses are printed, appended to ``log_train.txt`` and plotted.
    A checkpoint is written whenever ``iteration % 10000 == 0``. The process calls
    ``sys.exit()`` once ``opt.num_iter`` iterations are reached.
    """
    lib.print_model_settings(locals().copy())

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
        text_len = opt.batch_max_length+2
    else:
        converter = CTCLabelConverter(opt.character)
        text_len = opt.batch_max_length

    opt.classes = converter.character

    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    AlignCollate_valid = AlignPairCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)

    train_dataset = LmdbStyleDataset(root=opt.train_data, opt=opt)
    train_loader = torch.utils.data.DataLoader(
        # *2: first half feeds OCR / ground truth, second half is the discriminator's real images
        train_dataset, batch_size=opt.batch_size*2,
        shuffle=True,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)

    print('-' * 80)

    # NOTE(review): valid_loader is built but not used anywhere in this function.
    valid_dataset = LmdbStyleDataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size*2,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True, drop_last=True)

    print('-' * 80)
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a') as log:
        log.write('-' * 80 + '\n')

    text_dataset = text_gen(opt)
    text_loader = torch.utils.data.DataLoader(
        text_dataset, batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
        pin_memory=True, drop_last=True)
    opt.num_class = len(converter.character)

    # ---- models ----
    c_code_size = opt.latent
    cEncoder = GlobalContentEncoder(opt.num_class, text_len, opt.char_embed_size, c_code_size)
    ocrModel = ModelV1(opt)

    genModel = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier)
    g_ema = styleGANGen(opt.size, opt.latent, opt.latent, opt.n_mlp, channel_multiplier=opt.channel_multiplier)

    disEncModel = styleGANDis(opt.size, channel_multiplier=opt.channel_multiplier, input_dim=opt.input_channel, code_s_dim=c_code_size)

    accumulate(g_ema, genModel, 0)

    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        print('Not implemented error')
        sys.exit()

    cEncoder = torch.nn.DataParallel(cEncoder).to(device)
    cEncoder.train()
    genModel = torch.nn.DataParallel(genModel).to(device)
    g_ema = torch.nn.DataParallel(g_ema).to(device)
    genModel.train()
    g_ema.eval()

    disEncModel = torch.nn.DataParallel(disEncModel).to(device)
    disEncModel.train()

    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    if opt.ocrFixed:
        if opt.Transformation == 'TPS':
            ocrModel.module.Transformation.eval()
        ocrModel.module.FeatureExtraction.eval()
        ocrModel.module.AdaptiveAvgPool.eval()
        ocrModel.module.Prediction.eval()
    else:
        ocrModel.train()

    # lr and betas are scaled by reg_every / (reg_every + 1) -- presumably the
    # StyleGAN2 lazy-regularisation correction.
    g_reg_ratio = opt.g_reg_every / (opt.g_reg_every + 1)
    d_reg_ratio = opt.d_reg_every / (opt.d_reg_every + 1)

    optimizer = optim.Adam(
        list(genModel.parameters())+list(cEncoder.parameters()),
        lr=opt.lr * g_reg_ratio,
        betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    dis_optimizer = optim.Adam(
        disEncModel.parameters(),
        lr=opt.lr * d_reg_ratio,
        betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )

    ocr_optimizer = optim.Adam(
        ocrModel.parameters(),
        lr=opt.lr,
        betas=(0.9, 0.99),
    )

    # ---- loading pre-trained files ----
    if opt.modelFolderFlag:
        saved_synth = glob.glob(os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))
        if saved_synth:
            # glob() returns files in arbitrary order; resume from the highest iteration.
            opt.saved_synth_model = max(saved_synth, key=_checkpoint_iteration)

    if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None':
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        checkpoint = torch.load(opt.saved_ocr_model)
        ocrModel.load_state_dict(checkpoint)

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        checkpoint = torch.load(opt.saved_synth_model)

        # The content encoder is stored in every checkpoint written below, so it must be
        # restored as well; older checkpoints may not contain the key.
        if 'cEncoder' in checkpoint:
            cEncoder.load_state_dict(checkpoint['cEncoder'])
        genModel.load_state_dict(checkpoint['genModel'])
        g_ema.load_state_dict(checkpoint['g_ema'])
        disEncModel.load_state_dict(checkpoint['disEncModel'])
        ocrModel.load_state_dict(checkpoint['ocrModel'])

        optimizer.load_state_dict(checkpoint["optimizer"])
        dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])
        ocr_optimizer.load_state_dict(checkpoint["ocr_optimizer"])

    # loss averagers
    loss_avg_dis = Averager()
    loss_avg_gen = Averager()
    log_r1_val = Averager()
    log_avg_path_loss_val = Averager()
    log_avg_mean_path_length_avg = Averager()
    log_ada_aug_p = Averager()
    loss_avg_ocr_sup = Averager()
    loss_avg_ocr_unsup = Averager()

    # ---- final options ----
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        try:
            start_iter = int(opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except (IndexError, ValueError):
            pass  # file name does not encode an iteration; start counting from 0

    scheduler = get_scheduler(optimizer, opt)
    dis_scheduler = get_scheduler(dis_optimizer, opt)
    ocr_scheduler = get_scheduler(ocr_optimizer, opt)

    start_time = time.time()
    iteration = start_iter
    cntr = 0

    mean_path_length = 0
    r1_loss = torch.tensor(0.0, device=device)
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0

    accum = 0.5 ** (32 / (10 * 1000))  # EMA decay used to update g_ema
    ada_aug_p = opt.augment_p if opt.augment_p > 0 else 0.0  # only logged in this function
    r_t_stat = 0
    # NOTE(review): 10e-50 is below float32's smallest subnormal (~1.4e-45); added to a
    # float32 std it rounds away, so it does not guard against a zero denominator.
    epsilon = 10e-50

    # One persistent iterator per loader (see _infinite_batches).
    train_batches = _infinite_batches(train_loader)
    text_batches = _infinite_batches(text_loader)

    while True:
        # ================= Discriminator update =================
        image_input_tensors, _, labels, _ = next(train_batches)
        labels_z_c = next(text_batches)

        image_input_tensors = image_input_tensors.to(device)
        gt_image_tensors = image_input_tensors[:opt.batch_size].detach()
        real_image_tensors = image_input_tensors[opt.batch_size:].detach()

        labels_gt = labels[:opt.batch_size]

        requires_grad(cEncoder, False)
        requires_grad(genModel, False)
        requires_grad(disEncModel, True)
        requires_grad(ocrModel, False)

        text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length)
        text_gt, length_gt = converter.encode(labels_gt, batch_max_length=opt.batch_max_length)

        z_c_code = cEncoder(text_z_c)
        noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
        style = _compose_style(noise_style, z_c_code, opt)

        fake_img, _ = genModel(style, input_is_latent=opt.input_latent)

        # Domain discriminator
        fake_pred = disEncModel(fake_img)
        real_pred = disEncModel(real_image_tensors)
        disCost = d_logistic_loss(real_pred, fake_pred)
        # Detached scores kept only for wandb logging.
        real_score = real_pred.detach().mean()
        fake_score = fake_pred.detach().mean()

        loss_avg_dis.add(disCost)

        disEncModel.zero_grad()
        disCost.backward()
        dis_optimizer.step()

        if cntr % opt.d_reg_every == 0:
            # R1 penalty on real images.
            real_image_tensors.requires_grad = True
            real_pred = disEncModel(real_image_tensors)

            r1_loss = d_r1_loss(real_pred, real_image_tensors)

            disEncModel.zero_grad()
            (opt.r1 / 2 * r1_loss * opt.d_reg_every + 0 * real_pred[0]).backward()

            dis_optimizer.step()
        log_r1_val.add(r1_loss)

        # ================= Recognizer update =================
        if not opt.ocrFixed and not opt.zAlone:
            requires_grad(disEncModel, False)
            requires_grad(ocrModel, True)

            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(gt_image_tensors, text_gt, is_train=True)
                preds_size = torch.IntTensor([preds_recon.size(1)] * opt.batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_gt, preds_size, length_gt)
            else:
                print("Not implemented error")
                sys.exit()

            ocrModel.zero_grad()
            ocrCost.backward()
            ocr_optimizer.step()
            loss_avg_ocr_sup.add(ocrCost)
        else:
            loss_avg_ocr_sup.add(torch.tensor(0.0))

        # ================= Word generator update =================
        labels_z_c = next(text_batches)

        requires_grad(cEncoder, True)
        requires_grad(genModel, True)
        requires_grad(disEncModel, False)
        requires_grad(ocrModel, False)

        text_z_c, length_z_c = converter.encode(labels_z_c, batch_max_length=opt.batch_max_length)

        z_c_code = cEncoder(text_z_c)
        noise_style = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
        style = _compose_style(noise_style, z_c_code, opt)

        fake_img, _ = genModel(style, input_is_latent=opt.input_latent)

        fake_pred = disEncModel(fake_img)
        disGenCost = g_nonsaturating_loss(fake_pred)

        if opt.zAlone:
            # Constant (no graph): the OCR term does not contribute gradients.
            ocrCost = torch.tensor(0.0)
        else:
            # OCR prediction on generated images (reconstruction of content).
            if 'CTC' in opt.Prediction:
                preds_recon = ocrModel(fake_img, text_z_c, is_train=False)
                preds_size = torch.IntTensor([preds_recon.size(1)] * opt.batch_size)
                preds_recon_softmax = preds_recon.log_softmax(2).permute(1, 0, 2)
                ocrCost = ocrCriterion(preds_recon_softmax, text_z_c, preds_size, length_z_c)
            else:
                print("Not implemented error")
                sys.exit()

        genModel.zero_grad()
        cEncoder.zero_grad()

        gen_enc_cost = disGenCost + opt.ocrWeight * ocrCost

        # Gradient balancing needs d(ocrCost)/d(fake_img), which does not exist in
        # zAlone mode (ocrCost is a constant there).
        if opt.grad_balance and not opt.zAlone:
            gen_enc_cost.backward(retain_graph=True)
            grad_fake_OCR = torch.autograd.grad(ocrCost, fake_img, create_graph=True, retain_graph=True)[0]
            grad_fake_adv = torch.autograd.grad(disGenCost, fake_img, create_graph=True, retain_graph=True)[0]
            # Rescale the OCR loss so its image-gradient std matches the adversarial
            # one, times opt.ocrWeight.
            a = opt.ocrWeight * torch.div(torch.std(grad_fake_adv), epsilon+torch.std(grad_fake_OCR))
            if a > 1000 or a < 0.0001:
                print(a)

            ocrCost = a.detach() * ocrCost
            gen_enc_cost = disGenCost + ocrCost
            # NOTE(review): three backward() calls run here without zero_grad() in
            # between, so their gradients accumulate before optimizer.step(); confirm
            # that this (inherited) behaviour is intended.
            gen_enc_cost.backward(retain_graph=True)
            with torch.no_grad():
                gen_enc_cost.backward()
        else:
            gen_enc_cost.backward()

        loss_avg_gen.add(disGenCost)
        loss_avg_ocr_unsup.add(opt.ocrWeight * ocrCost)

        optimizer.step()

        # ================= Path-length regularisation =================
        if cntr % opt.g_reg_every == 0:
            path_batch_size = max(1, opt.batch_size // opt.path_batch_shrink)
            labels_z_c = next(text_batches)

            text_z_c, length_z_c = converter.encode(labels_z_c[:path_batch_size], batch_max_length=opt.batch_max_length)

            z_c_code = cEncoder(text_z_c)
            noise_style = mixing_noise_style(path_batch_size, opt.latent, opt.mixing, device)
            style = _compose_style(noise_style, z_c_code, opt)

            fake_img, grad = genModel(style, return_latents=True, g_path_regularize=True, mean_path_length=mean_path_length)

            decay = 0.01
            path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))

            mean_path_length_orig = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
            path_loss = (path_lengths - mean_path_length_orig).pow(2).mean()
            mean_path_length = mean_path_length_orig.detach().item()

            genModel.zero_grad()
            cEncoder.zero_grad()
            weighted_path_loss = opt.path_regularize * opt.g_reg_every * path_loss

            if opt.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            optimizer.step()

            # Single-process setting: no cross-GPU reduction of the mean path length.
            mean_path_length_avg = mean_path_length

        accumulate(g_ema, genModel, accum)

        # Since PyTorch 1.1 schedulers must step after the optimizers.
        if opt.lr_policy != "None":
            scheduler.step()
            dis_scheduler.step()
            ocr_scheduler.step()

        log_avg_path_loss_val.add(path_loss)
        log_avg_mean_path_length_avg.add(torch.tensor(mean_path_length_avg))
        log_ada_aug_p.add(torch.tensor(ada_aug_p))

        if get_rank() == 0 and wandb and opt.wandb:
            wandb.log(
                {
                    "Generator": disGenCost.item(),
                    "Discriminator": disCost.item(),
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_loss.item(),
                    "Path Length Regularization": path_loss.item(),
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score.item(),
                    "Fake Score": fake_score.item(),
                    "Path Length": path_lengths.mean().item(),
                }
            )

        # ================= Validation / sampling =================
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:  # also at iteration 0 to see training progress
            # Sampling for inspection only: no autograd graph is needed.
            with torch.no_grad():
                # c1/c2 share style s1; c1 is also rendered with a second style s2.
                labels_z_c_1 = next(text_batches)
                labels_z_c_2 = next(text_batches)

                text_z_c_1, length_z_c_1 = converter.encode(labels_z_c_1, batch_max_length=opt.batch_max_length)
                text_z_c_2, length_z_c_2 = converter.encode(labels_z_c_2, batch_max_length=opt.batch_max_length)

                z_c_code_1 = cEncoder(text_z_c_1)
                z_c_code_2 = cEncoder(text_z_c_2)

                style_s1 = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
                noise_style_s2 = mixing_noise_style(opt.batch_size, opt.latent, opt.mixing, device)
                style_c1_s1 = _compose_style(style_s1, z_c_code_1, opt)
                if opt.zAlone:
                    # In zAlone mode all three samples reuse the c1/s1 style.
                    style_c2_s1 = style_c1_s1
                    style_c1_s2 = style_c1_s1
                else:
                    style_c2_s1 = _compose_style(style_s1, z_c_code_2, opt)
                    style_c1_s2 = _compose_style(noise_style_s2, z_c_code_1, opt)

                fake_img_c1_s1, _ = g_ema(style_c1_s1, input_is_latent=opt.input_latent)
                fake_img_c2_s1, _ = g_ema(style_c2_s1, input_is_latent=opt.input_latent)
                fake_img_c1_s2, _ = g_ema(style_c1_s2, input_is_latent=opt.input_latent)

                if not opt.zAlone:
                    # Run OCR prediction (greedy CTC decoding).
                    if 'CTC' in opt.Prediction:
                        preds = ocrModel(fake_img_c1_s1, text_z_c_1, is_train=False)
                        preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
                        _, preds_index = preds.max(2)
                        preds_str_fake_img_c1_s1 = converter.decode(preds_index.data, preds_size.data)

                        preds = ocrModel(fake_img_c2_s1, text_z_c_2, is_train=False)
                        preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
                        _, preds_index = preds.max(2)
                        preds_str_fake_img_c2_s1 = converter.decode(preds_index.data, preds_size.data)

                        preds = ocrModel(fake_img_c1_s2, text_z_c_1, is_train=False)
                        preds_size = torch.IntTensor([preds.size(1)] * opt.batch_size)
                        _, preds_index = preds.max(2)
                        preds_str_fake_img_c1_s2 = converter.decode(preds_index.data, preds_size.data)

                        preds = ocrModel(gt_image_tensors, text_gt, is_train=False)
                        preds_size = torch.IntTensor([preds.size(1)] * gt_image_tensors.shape[0])
                        _, preds_index = preds.max(2)
                        preds_str_gt = converter.decode(preds_index.data, preds_size.data)

                    else:
                        print("Not implemented error")
                        sys.exit()
                else:
                    preds_str_fake_img_c1_s1 = [':None:'] * fake_img_c1_s1.shape[0]
                    preds_str_gt = [':None:'] * fake_img_c1_s1.shape[0]

            os.makedirs(os.path.join(opt.trainDir, str(iteration)), exist_ok=True)
            for trImgCntr in range(opt.batch_size):
                try:
                    save_image(tensor2im(fake_img_c1_s1[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c1_s1_'+labels_z_c_1[trImgCntr]+'_ocr:'+preds_str_fake_img_c1_s1[trImgCntr]+'.png'))
                    if not opt.zAlone:
                        save_image(tensor2im(fake_img_c2_s1[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c2_s1_'+labels_z_c_2[trImgCntr]+'_ocr:'+preds_str_fake_img_c2_s1[trImgCntr]+'.png'))
                        save_image(tensor2im(fake_img_c1_s2[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_c1_s2_'+labels_z_c_1[trImgCntr]+'_ocr:'+preds_str_fake_img_c1_s2[trImgCntr]+'.png'))
                        if trImgCntr<gt_image_tensors.shape[0]:
                            save_image(tensor2im(gt_image_tensors[trImgCntr].detach()),os.path.join(opt.trainDir,str(iteration),str(trImgCntr)+'_gt_act:'+labels_gt[trImgCntr]+'_ocr:'+preds_str_gt[trImgCntr]+'.png'))
                except Exception as exc:
                    # Best effort (labels may contain characters invalid in file names).
                    print(f'Warning while saving training image: {exc}')

            elapsed_time = time.time() - start_time

            with open(os.path.join(opt.exp_dir,opt.exp_name,'log_train.txt'), 'a') as log:

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}]  \
                    Train Dis loss: {loss_avg_dis.val():0.5f}, Train Gen loss: {loss_avg_gen.val():0.5f},\
                    Train UnSup OCR loss: {loss_avg_ocr_unsup.val():0.5f}, Train Sup OCR loss: {loss_avg_ocr_sup.val():0.5f}, \
                    Train R1-val loss: {log_r1_val.val():0.5f}, Train avg-path-loss: {log_avg_path_loss_val.val():0.5f}, \
                    Train mean-path-length loss: {log_avg_mean_path_length_avg.val():0.5f}, Train ada-aug-p: {log_ada_aug_p.val():0.5f}, \
                    Elapsed_time: {elapsed_time:0.5f}'

                # plotting
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Dis-Loss'), loss_avg_dis.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Gen-Loss'), loss_avg_gen.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-UnSup-OCR-Loss'), loss_avg_ocr_unsup.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-Sup-OCR-Loss'), loss_avg_ocr_sup.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-r1_val'), log_r1_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-path_loss_val'), log_avg_path_loss_val.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-mean_path_length_avg'), log_avg_mean_path_length_avg.val().item())
                lib.plot.plot(os.path.join(opt.plotDir,'Train-ada_aug_p'), log_ada_aug_p.val().item())

                print(loss_log)
                # The log file was opened for exactly this purpose; persist the line.
                log.write(loss_log + '\n')

                loss_avg_dis.reset()
                loss_avg_gen.reset()
                loss_avg_ocr_unsup.reset()
                loss_avg_ocr_sup.reset()
                log_r1_val.reset()
                log_avg_path_loss_val.reset()
                log_avg_mean_path_length_avg.reset()
                log_ada_aug_p.reset()

            lib.plot.flush()

        lib.plot.tick()

        # Save a checkpoint every 10,000 iterations (including iteration 0).
        if iteration % 10000 == 0:
            torch.save({
                'cEncoder':cEncoder.state_dict(),
                'genModel':genModel.state_dict(),
                'g_ema':g_ema.state_dict(),
                'ocrModel':ocrModel.state_dict(),
                'disEncModel':disEncModel.state_dict(),
                'optimizer':optimizer.state_dict(),
                'ocr_optimizer':ocr_optimizer.state_dict(),
                'dis_optimizer':dis_optimizer.state_dict()},
                os.path.join(opt.exp_dir,opt.exp_name,'iter_'+str(iteration+1)+'_synth.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
        cntr += 1
Esempio n. 30
0
def _strip_after_eos(text, eos='[s]'):
    """Return ``(prefix, end)``: the part of *text* before the first *eos* and its index.

    ``str.find`` returns -1 when *eos* is absent; slicing with that value would
    silently drop the last character, so a missing *eos* keeps the whole string
    (``end`` is then ``len(text)``).
    """
    end = text.find(eos)
    if end == -1:
        end = len(text)
    return text[:end], end


def _icdar2019_similarity(pred, gt):
    """Return the ICDAR2019 normalized edit similarity of one word.

    Computed as ``1 - edit_distance(pred, gt) / max(len(pred), len(gt))``.
    A pair where either string is empty contributes 0. This includes the case
    where both are empty, which keeps the original scoring behavior.
    """
    if len(gt) == 0 or len(pred) == 0:
        return 0
    return 1 - edit_distance(pred, gt) / max(len(gt), len(pred))


def _confidence_score(pred_max_prob):
    """Return the product of the per-step max probabilities, or 0 if it cannot be computed.

    The product is the last element of ``cumprod(dim=0)``. It falls back to 0 in
    two cases:
    - an empty sequence, e.g. after pruning at the "end of sentence" token;
    - presumably, a placeholder numpy scalar that does not accept ``dim``.
    ``Exception`` is caught rather than using a bare ``except`` so that
    KeyboardInterrupt/SystemExit still propagate.
    """
    try:
        return pred_max_prob.cumprod(dim=0)[-1]
    except Exception:
        return 0


def validation(model, criterion, evaluation_loader, converter, opt):
    """Run ``model`` over ``evaluation_loader`` and compute recognition metrics.

    Args:
        model: Recognition network. How it is called depends on
            ``opt.Prediction``:
            - contains 'CTC': ``model(image, text_for_pred)``;
            - == 'None': called with transformer targets;
            - otherwise: ``model(image, text_for_pred, is_train=False)``.
        criterion: Loss function matching ``opt.Prediction``.
        evaluation_loader: Iterable of ``(image_tensors, labels)`` batches.
        converter: Label converter providing ``encode`` / ``decode``.
        opt: Options namespace. Reads ``batch_max_length``,
            ``SequenceModeling``, ``Prediction``, ``baiduCTC``,
            ``beam_search``, ``sensitive`` and ``data_filtering_off``.

    Returns:
        Tuple ``(valid_loss, accuracy, norm_ED, preds_str,
        confidence_score_list, labels, infer_time, length_of_data)``:
        - ``accuracy`` is a percentage of exact matches.
        - ``norm_ED`` is the mean ICDAR2019 normalized edit similarity
          (higher is better).
        - ``infer_time`` is the summed forward-pass time in seconds.
        - ``preds_str``, ``confidence_score_list`` and ``labels`` describe
          the LAST batch only.
        An empty loader yields zero accuracy and zero ``norm_ED``.
    """
    n_correct = 0
    norm_ED = 0
    length_of_data = 0
    infer_time = 0
    valid_loss_avg = Averager()
    # Initialised so an empty loader returns cleanly instead of hitting unbound names.
    preds_str, labels, confidence_score_list = [], [], []
    # Used to evaluate a case-sensitive model in the case-insensitive alphanumeric setting.
    alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz'
    out_of_alphanumeric = re.compile(f'[^{alphanumeric_case_insensitve}]')

    # Evaluation never backpropagates; skip autograd bookkeeping to save memory.
    with torch.no_grad():
        for image_tensors, labels in evaluation_loader:
            batch_size = image_tensors.size(0)
            length_of_data += batch_size
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.zeros(batch_size, opt.batch_max_length + 1, dtype=torch.long, device=device)

            if opt.SequenceModeling == 'Transformer':
                text_for_loss, length_for_loss = converter.encode(
                    labels, batch_max_length=opt.batch_max_length, train=False)
            else:
                text_for_loss, length_for_loss = converter.encode(
                    labels, batch_max_length=opt.batch_max_length)

            start_time = time.time()
            preds_prob = None
            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)
                forward_time = time.time() - start_time

                # CTC loss is fed (T, N, C) inputs, hence the permute.
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                if opt.baiduCTC:
                    # NOTE(review): presumably this loss is summed over the batch, hence / batch_size; confirm.
                    cost = criterion(preds.permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) / batch_size
                else:
                    cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss)

                # Greedy decoding is the same for both CTC variants:
                # take the argmax at each step, then decode indices to characters.
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index.view(-1).data, preds_size.data)

            elif opt.Prediction == 'None':
                tgt_input = text_for_loss['tgt_input']
                tgt_output = text_for_loss['tgt_output']
                tgt_padding_mask = text_for_loss['tgt_padding_mask']
                preds = model(image, tgt_input.transpose(0, 1), tgt_key_padding_mask=tgt_padding_mask,)
                forward_time = time.time() - start_time
                cost = criterion(preds.view(-1, preds.shape[-1]), tgt_output.contiguous().view(-1))

                preds_str, preds_prob = predict(model, image, converter, beam_search=opt.beam_search, max_seq_length=64)
                labels = converter.decode(tgt_input)

            else:
                preds = model(image, text_for_pred, is_train=False)
                forward_time = time.time() - start_time

                preds = preds[:, :text_for_loss.shape[1] - 1, :]
                target = text_for_loss[:, 1:]  # without [GO] Symbol
                cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1))

                # Greedy decoding: argmax per step, then decode indices to characters.
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)
                labels = converter.decode(text_for_loss[:, 1:], length_for_loss)

            infer_time += forward_time
            valid_loss_avg.add(cost)

            # Per-step max probabilities, used for the confidence score.
            if opt.Prediction == 'None':
                preds_max_prob = preds_prob if preds_prob is not None else np.zeros(len(labels))
            else:
                preds_max_prob, _ = F.softmax(preds, dim=2).max(dim=2)

            confidence_score_list = []
            for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
                if 'Attn' in opt.Prediction:
                    # Prune after the "end of sentence" token ([s]); keep everything if it is absent.
                    gt, _ = _strip_after_eos(gt)
                    pred, pred_EOS = _strip_after_eos(pred)
                    pred_max_prob = pred_max_prob[:pred_EOS]

                # Evaluate a 'case sensitive model' in the alphanumeric, case-insensitive setting.
                if opt.sensitive and opt.data_filtering_off:
                    pred = out_of_alphanumeric.sub('', pred.lower())
                    gt = out_of_alphanumeric.sub('', gt.lower())

                if pred == gt:
                    n_correct += 1

                # ICDAR2019 Normalized Edit Distance (reported as a similarity: 1 - N.E.D.)
                norm_ED += _icdar2019_similarity(pred, gt)

                confidence_score_list.append(_confidence_score(pred_max_prob))

    if length_of_data == 0:
        # Nothing was evaluated; report zeros rather than dividing by zero.
        accuracy = 0.0
        norm_ED = 0.0
    else:
        accuracy = n_correct / float(length_of_data) * 100
        norm_ED = norm_ED / float(length_of_data)  # ICDAR2019 Normalized Edit Distance

    return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data