예제 #1
0
def train(opt):
    """Train a text-recognition model with a single CTC or attention head.

    Builds the balanced training dataset and a validation loader, then the
    model, label converter, loss and optimizer selected by ``opt``, and runs
    the training loop. Validation runs at iteration 0 and then every
    ``opt.valInterval`` iterations; the best-accuracy and best-norm-ED
    weights are written under ``./saved_models/{opt.exp_name}/``.

    Args:
        opt: Option namespace. It is mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set from
            the converter, and ``input_channel`` is forced to 3 when
            ``opt.rgb`` is set.

    Note:
        The function never returns normally; it calls ``sys.exit()`` once
        ``opt.num_iter`` iterations have been reached.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager closes the dataset log even if building the
    # validation dataset or loader raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=
            True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # Fallback for weights the initializer above rejects: fill with 1.
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            # Fine-tuning: tolerate missing/unexpected keys in the checkpoint.
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        # Best effort: resume the iteration counter from a checkpoint name
        # ending in '_<number>.<ext>'; any other name keeps start_iter at 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while (True):
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if (
                iteration + 1
        ) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(
                model.state_dict(),
                f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
def train(opt):
    """Train a two-headed model: a CTC "major" path and an attention "guide" path.

    The model returns two prediction tensors per batch. The CTC head is
    optimised with ``optimizer_ctc`` over the ``SequenceModeling_CTC`` and
    ``CTC`` sub-modules; the attention head is optimised with
    ``optimizer_attn`` over ``Transformation``, ``FeatureExtraction``,
    ``SequenceModeling_Attn`` and ``Attention``. With
    ``opt.continue_training`` only the CTC path is trained and the weights are
    loaded from ``opt.saved_model`` as a plain state dict.

    Fixes relative to the previous revision:
      * ``opt.adam`` no longer raises ``NameError`` (it referenced an undefined
        ``filtered_parameters`` and never created the per-path optimizers);
        both paths now get an ``optim.Adam`` optimizer.
      * ``opt.continue_training`` no longer raises ``NameError`` when building
        or stepping the attention scheduler, which has no optimizer in that
        mode.
      * Log records are separated by newlines (and the T/F column by a tab),
        matching the single-head ``train`` in this file, instead of literal
        tab characters everywhere.
      * The dataset log file is closed even if dataset construction fails.

    Args:
        opt: Option namespace. Mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``; ``num_class_ctc`` and
            ``num_class_attn`` are set from the converters; ``input_channel``
            is forced to 3 when ``opt.rgb`` is set.

    Note:
        Relies on the module-level names ``writer`` (scalar logger) and
        ``fol_ckpt`` (checkpoint directory), which are not defined in this
        function. Terminates the process with ``sys.exit()`` once
        ``opt.num_iter`` iterations have been reached.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=
            True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    # Both converters are always needed because the model has both heads.
    if opt.baiduCTC:
        CTC_converter = CTCLabelConverterForBaiduWarpctc(opt.character)
    else:
        CTC_converter = CTCLabelConverter(opt.character)
    Attn_converter = AttnLabelConverter(opt.character)
    opt.num_class_ctc = len(CTC_converter.character)
    opt.num_class_attn = len(Attn_converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class_ctc, opt.num_class_attn, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling,
          opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # Fallback for weights the initializer above rejects: fill with 1.
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    print("Model:")
    print(model)

    # ---- setup loss ----
    if opt.baiduCTC:
        # need to install warpctc. see our guideline.
        if opt.label_smooth:
            criterion_major_path = SmoothCTCLoss(num_classes=opt.num_class_ctc,
                                                 weight=0.05)
        else:
            criterion_major_path = CTCLoss()
    else:
        criterion_major_path = torch.nn.CTCLoss(zero_infinity=True).to(device)
    # ignore [GO] token = ignore index 0
    criterion_guide_path = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)
    loss_avg_major_path = Averager()
    loss_avg_guide_path = Averager()

    # Split trainable parameters by the sub-module that owns them. Names are
    # prefixed by the DataParallel wrapper, so component [1] is the sub-module.
    guide_parameters = []
    major_parameters = []
    guide_model_part_names = [
        "Transformation", "FeatureExtraction", "SequenceModeling_Attn",
        "Attention"
    ]
    major_model_part_names = ["SequenceModeling_CTC", "CTC"]
    for name, param in model.named_parameters():
        if param.requires_grad:
            if name.split(".")[1] in guide_model_part_names:
                guide_parameters.append(param)
            elif name.split(".")[1] in major_model_part_names:
                major_parameters.append(param)

    # In continue-training mode the guide path is frozen: it gets no
    # parameters, no optimizer and no scheduler.
    train_guide = not opt.continue_training
    if not train_guide:
        guide_parameters = []

    # ---- setup optimizers ----
    def build_optimizer(params):
        if opt.adam:
            return optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
        return AdamW(params, lr=opt.lr)

    def build_scheduler(optimizer, **extra):
        return get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=10000,
            num_training_steps=opt.num_iter,
            **extra)

    optimizer_ctc = build_optimizer(major_parameters)
    optimizer_attn = build_optimizer(guide_parameters) if train_guide else None
    scheduler_ctc = build_scheduler(optimizer_ctc)
    scheduler_attn = build_scheduler(optimizer_attn) if train_guide else None

    # ---- optional resume from a full training checkpoint ----
    start_iter = 0
    if opt.saved_model != '' and train_guide:
        print(f'loading pretrained model from {opt.saved_model}')
        checkpoint = torch.load(opt.saved_model)
        start_iter = checkpoint['start_iter'] + 1
        if not opt.adam:
            optimizer_ctc.load_state_dict(
                checkpoint['optimizer_ctc_state_dict'])
            optimizer_attn.load_state_dict(
                checkpoint['optimizer_attn_state_dict'])
            scheduler_ctc.load_state_dict(
                checkpoint['scheduler_ctc_state_dict'])
            scheduler_attn.load_state_dict(
                checkpoint['scheduler_attn_state_dict'])
            print(scheduler_ctc.get_lr())
            print(scheduler_attn.get_lr())
        if opt.FT:
            model.load_state_dict(checkpoint['model_state_dict'], strict=False)
        else:
            model.load_state_dict(checkpoint['model_state_dict'])
    if opt.continue_training:
        # Here the file is a bare state dict, not a training checkpoint.
        model.load_state_dict(torch.load(opt.saved_model))

    # The schedulers are rebuilt so that they start at the resumed iteration;
    # the instances created above are replaced.
    scheduler_ctc = build_scheduler(optimizer_ctc, last_epoch=start_iter - 1)
    if train_guide:
        scheduler_attn = build_scheduler(optimizer_attn,
                                         last_epoch=start_iter - 1)

    # ---- final options ----
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    # Incremented at the top of the loop, so the first processed iteration
    # is start_iter.
    iteration = start_iter - 1
    while (True):
        # train part
        image_tensors, labels = train_dataset.get_batch()
        iteration += 1
        image = image_tensors.to(device)
        text_attn, length_attn = Attn_converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        text_ctc, length_ctc = CTC_converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)
        # align with Attention.forward: the last target token is not fed in
        preds_major, preds_guide = model(image, text_attn[:, :-1])
        preds_size = torch.IntTensor([preds_major.size(1)] * batch_size)
        if opt.baiduCTC:
            preds_major = preds_major.permute(1, 0, 2)  # to use CTCLoss format
            if opt.label_smooth:
                cost_ctc = criterion_major_path(preds_major, text_ctc,
                                                preds_size, length_ctc,
                                                batch_size)
            else:
                cost_ctc = criterion_major_path(
                    preds_major, text_ctc, preds_size, length_ctc) / batch_size
        else:
            preds_major = preds_major.log_softmax(2).permute(1, 0, 2)
            cost_ctc = criterion_major_path(preds_major, text_ctc, preds_size,
                                            length_ctc)

        if train_guide:
            target = text_attn[:, 1:]  # without [GO] Symbol
            cost_attn = criterion_guide_path(
                preds_guide.view(-1, preds_guide.shape[-1]),
                target.contiguous().view(-1))
            optimizer_attn.zero_grad()
            # The graph is retained because cost_ctc is back-propagated next.
            cost_attn.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(
                guide_parameters,
                opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer_attn.step()
        optimizer_ctc.zero_grad()
        cost_ctc.backward()
        torch.nn.utils.clip_grad_norm_(
            major_parameters,
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer_ctc.step()
        scheduler_ctc.step()
        if train_guide:
            scheduler_attn.step()

        loss_avg_major_path.add(cost_ctc)
        if train_guide:
            loss_avg_guide_path.add(cost_attn)
        if (iteration + 1) % 100 == 0:
            writer.add_scalar("Loss/train_ctc", loss_avg_major_path.val(),
                              (iteration + 1) // 100)
            loss_avg_major_path.reset()
            if train_guide:
                writer.add_scalar("Loss/train_attn", loss_avg_guide_path.val(),
                                  (iteration + 1) // 100)
                loss_avg_guide_path.reset()

        # validation part
        if (iteration + 1) % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion_major_path, valid_loader,
                        CTC_converter, opt)
                model.train()
                writer.add_scalar("Loss/valid", valid_loss,
                                  (iteration + 1) // opt.valInterval)
                writer.add_scalar("Metrics/accuracy", current_accuracy,
                                  (iteration + 1) // opt.valInterval)
                writer.add_scalar("Metrics/norm_ED", current_norm_ED,
                                  (iteration + 1) // opt.valInterval)

                # training loss and validation loss
                # NOTE(review): when opt.valInterval is a multiple of 100 the
                # averagers were reset a few lines above, so the train losses
                # logged here come from freshly reset averagers - confirm
                # this is intended.
                if train_guide:
                    loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss ctc: {loss_avg_major_path.val():0.5f}, Train loss attn: {loss_avg_guide_path.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                else:
                    loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss ctc: {loss_avg_major_path.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg_major_path.reset()
                if train_guide:
                    loss_avg_guide_path.reset()
                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(),
                               f'{fol_ckpt}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(),
                               f'{fol_ckpt}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # Save a resumable checkpoint every 1e+3 iterations. Skipped in
        # continue-training mode, where the attention optimizer/scheduler
        # do not exist.
        if (iteration + 1) % 1e+3 == 0 and train_guide:
            torch.save(
                {
                    'model_state_dict': model.state_dict(),
                    'optimizer_attn_state_dict': optimizer_attn.state_dict(),
                    'optimizer_ctc_state_dict': optimizer_ctc.state_dict(),
                    'start_iter': iteration,
                    'scheduler_ctc_state_dict': scheduler_ctc.state_dict(),
                    'scheduler_attn_state_dict': scheduler_attn.state_dict(),
                }, f'{fol_ckpt}/current_model.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
예제 #3
0
def test(opt):
    """Evaluate a saved CTC + attention model on ``opt.eval_data``.

    Loads ``opt.saved_model`` (non-strict), copies the checkpoint into
    ``./result/{opt.exp_name}/`` and evaluates on a loader built from
    ``opt.eval_data``. With ``opt.benchmark_all_eval`` the work is delegated
    to ``benchmark_all_eval``; otherwise ``validation_ctc_and_attn`` is run
    once, both accuracies are printed and the CTC accuracy is appended to
    ``log_evaluation.txt``.

    Args:
        opt: Option namespace. Mutated in place: ``num_class_ctc``,
            ``num_class_attn`` and ``exp_name`` are set, and
            ``input_channel`` is forced to 3 when ``opt.rgb`` is set.
    """
    # Local import so this function does not depend on the file's import block.
    import shutil

    # ---- model configuration ----
    if opt.baiduCTC:
        converter_ctc = CTCLabelConverterForBaiduWarpctc(opt.character)
    else:
        converter_ctc = CTCLabelConverter(opt.character)
    converter_attn = AttnLabelConverter(opt.character)
    opt.num_class_ctc = len(converter_ctc.character)
    opt.num_class_attn = len(converter_attn.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device),
                          strict=False)
    # Experiment name = checkpoint path without its first component,
    # with '/' replaced by '_'.
    opt.exp_name = '_'.join(opt.saved_model.split('/')[1:])

    # ---- keep evaluation model and result logs ----
    result_dir = f'./result/{opt.exp_name}'
    os.makedirs(result_dir, exist_ok=True)
    # shutil.copy instead of a shell 'cp' string: handles spaces and shell
    # metacharacters in the path and works without a POSIX shell. The copy
    # stays best-effort, as the ignored 'cp' exit status was before.
    try:
        shutil.copy(opt.saved_model, f'{result_dir}/')
    except OSError as err:
        print(f'could not copy {opt.saved_model} to {result_dir}: {err}')

    # ---- setup loss ----
    criterion_ctc = torch.nn.CTCLoss(zero_infinity=True).to(device)
    criterion_attn = torch.nn.CrossEntropyLoss(ignore_index=0).to(
        device)  # ignore [GO] token = ignore index 0

    # ---- evaluation ----
    model.eval()
    # The log is opened (and therefore created) in both modes, as before, but
    # is now always closed, including when evaluation raises.
    with torch.no_grad(), open(f'{result_dir}/log_evaluation.txt',
                               'a') as log:
        AlignCollate_evaluation = AlignCollate(imgH=opt.imgH,
                                               imgW=opt.imgW,
                                               keep_ratio_with_pad=opt.PAD)
        eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data,
                                                        opt=opt)
        evaluation_loader = torch.utils.data.DataLoader(
            eval_data,
            batch_size=opt.batch_size,
            shuffle=False,
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_evaluation,
            pin_memory=True)

        if opt.benchmark_all_eval:
            benchmark_all_eval(model, criterion_ctc, criterion_attn,
                               evaluation_loader, converter_ctc,
                               converter_attn, opt)
        else:
            # Only the CTC accuracy (index 1) and the attention accuracy
            # (index 9) of the 13 returned values are used.
            _, accuracy_by_best_model, _, _, _, _, _, _, _, acc_attn, _, _, _ = validation_ctc_and_attn(
                model, criterion_ctc, criterion_attn, evaluation_loader,
                converter_ctc, converter_attn, opt)
            log.write(eval_data_log)
            print(f'{accuracy_by_best_model:0.3f}')
            print(f'{acc_attn:0.3f}')
            # Newline-terminated so successive runs append one record per line.
            log.write(f'{accuracy_by_best_model:0.3f}\n')
def demo(opt):
    """Run a saved model over ``opt.image_folder`` and report the predictions.

    For every image the predicted text and a confidence score (product of the
    per-step maximum probabilities) are printed and appended to
    ``./log_demo_result.txt``. The image (base64-embedded PNG), path, label
    and rounded confidence are also collected in a DataFrame which is written
    to ``result.html`` when ``opt.is_save`` is set.

    Fixes relative to the previous revision:
      * In attention mode, a prediction without an ``[s]`` token is no longer
        truncated by one character (``str.find`` returning -1 was used
        directly as a slice end).
      * A prediction that starts with ``[s]`` no longer raises ``IndexError``
        when computing the confidence; it is reported with confidence 0.
      * The log and HTML files are closed even if an error occurs.

    Args:
        opt: Option namespace. Mutated in place: ``num_class``,
            ``num_class_ctc`` and ``num_class_attn`` are set, and
            ``input_channel`` is forced to 3 when ``opt.rgb`` is set.
    """
    # ---- model configuration ----
    # The model class depends on whether the checkpoint was guide-trained.
    if opt.guide_training:
        from model_guide import Model
    else:
        from model import Model
    if opt.baiduCTC:
        converter = CTCLabelConverterForBaiduWarpctc(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)
    if opt.Prediction == 'Attn':
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    opt.num_class_ctc = opt.num_class
    opt.num_class_attn = opt.num_class_ctc + 1

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device), strict=False)

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    data = pd.DataFrame()
    with torch.no_grad():
        ind = 0  # running row index into `data` across batches
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                if opt.guide_training:
                    preds = model.module.inference(image, text_for_pred)
                else:
                    preds = model(image, text_for_pred)

                preds_size = torch.IntTensor([preds.size(1)] * batch_size)

                # Select max probabilty (greedy decoding) then decode index to character
                if opt.baiduCTC:
                    if opt.beam_search:
                        # The raw predictions are handed to the converter.
                        preds_index = preds
                    else:
                        _, preds_index = preds.max(2)
                        preds_index = preds_index.view(-1)
                else:
                    _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index.data, preds_size.data, opt.beam_search)
            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probabilty (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            dashed_line = '-' * 80
            head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

            with open('./log_demo_result.txt', 'a') as log:
                print(f'{dashed_line}\n{head}\n{dashed_line}')
                log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

                preds_prob = F.softmax(preds, dim=2)
                preds_max_prob, _ = preds_prob.max(dim=2)
                for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                    if 'Attn' in opt.Prediction:
                        pred_EOS = pred.find('[s]')
                        # find() returns -1 when there is no [s]; only prune
                        # when the token is actually present.
                        if pred_EOS != -1:
                            pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                            pred_max_prob = pred_max_prob[:pred_EOS]

                    # calculate confidence score (= multiply of pred_max_prob)
                    if pred_max_prob.numel() > 0:
                        confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                    else:
                        # Empty prediction (string started with [s]): there is
                        # no probability to multiply, so report 0.
                        confidence_score = pred_max_prob.new_zeros(())
                    filename = img_name
                    label = pred
                    conf = round(confidence_score.item(), 3)

                    # Re-read the image and embed it as a base64 PNG so that
                    # the HTML report is self-contained.
                    img = cv2.imread(filename)
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    img_pil = Image.fromarray(img)
                    img_buffer = io.BytesIO()
                    img_pil.save(img_buffer, format="PNG")
                    imgStr = base64.b64encode(img_buffer.getvalue()).decode("utf-8")

                    data.loc[ind, 'img'] = '<img src="data:image/png;base64,{0:s}">'.format(imgStr)
                    data.loc[ind, 'id'] = filename
                    data.loc[ind, 'label'] = label
                    data.loc[ind, 'conf'] = conf
                    ind += 1
                    print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                    log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')

        # escape=False keeps the <img> tags as markup in the generated table.
        html_all = data.to_html(escape=False)
        if opt.is_save:
            with open("result.html", "w") as text_file:
                text_file.write(html_all)
def train(opt):
    """Train the recognition model on the concatenated cut-image datasets.

    Builds training and validation loaders from hard-coded dataset
    directories, creates the label converter, model, loss and an Adam
    optimizer, then iterates over training batches until ``opt.num_iter``
    iterations have been performed. Validation runs at iteration 0 and every
    ``opt.valInterval`` iterations; the best-accuracy and best-norm-ED weights
    plus periodic snapshots are written under
    ``./saved_models/{opt.exp_name}/``.

    Args:
        opt: Option namespace. Attributes read here: ``data_filtering_off``,
            ``imgH``, ``imgW``, ``PAD``, ``batch_size``, ``workers``,
            ``Prediction``, ``baiduCTC``, ``character``, ``rgb``,
            ``num_fiducial``, ``input_channel``, ``output_channel``,
            ``hidden_size``, ``batch_max_length``, ``Transformation``,
            ``FeatureExtraction``, ``SequenceModeling``, ``num_gpu``,
            ``saved_model``, ``FT``, ``exp_name``, ``grad_clip``,
            ``valInterval`` and ``num_iter``. ``opt.num_class`` is set from
            the converter, and ``opt.input_channel`` is forced to 3 when
            ``opt.rgb`` is truthy.

    Note:
        This function does not return: it calls ``sys.exit()`` once
        ``opt.num_iter`` iterations are done. The optimizer is always Adam
        with a fixed learning rate of 0.0001; ``opt.lr`` is not consulted.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    # The same collate function is shared by the training and validation loaders.
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)

    train_dataset_xunfeieng = mytrdg_cutimg_dataset(total_img_path='/home/ldl/桌面/论文/文本识别/data/finish_data/eng_image/train/img',
        annotation_path='/home/ldl/桌面/论文/文本识别/data/finish_data/eng_image/train/gt')
    train_dataset_xunfeichn = mytrdg_cutimg_dataset(total_img_path='/home/ldl/桌面/论文/文本识别/data/finish_data/lan_image/train/img',
        annotation_path='/home/ldl/桌面/论文/文本识别/data/finish_data/lan_image/train/gt')
    train_dataset = ConcatDataset([train_dataset_xunfeichn, train_dataset_xunfeieng])
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid)

    valid_dataset_xunfeieng = mytrdg_cutimg_dataset(total_img_path='/home/ldl/桌面/论文/文本识别/data/finish_data/eng_image/test/img',
        annotation_path='/home/ldl/桌面/论文/文本识别/data/finish_data/eng_image/test/gt')
    valid_dataset_xunfeichn = mytrdg_cutimg_dataset(total_img_path='/home/ldl/桌面/论文/文本识别/data/finish_data/lan_image/test/img',
        annotation_path='/home/ldl/桌面/论文/文本识别/data/finish_data/lan_image/test/gt')
    valid_dataset = ConcatDataset([valid_dataset_xunfeichn, valid_dataset_xunfeieng])

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid)
    print('-' * 80)

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # Fall back to filling the weight with ones when the initializer
            # above rejects the parameter.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    if opt.num_gpu > 1:
        model = torch.nn.DataParallel(model).to(device)
    else:
        model.to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            # Fine-tuning: tolerate missing/unexpected keys in the checkpoint.
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    print(f"Train dataset length {len(train_dataset)}")
    # Fixed label: this line reports the validation set, not the training set.
    print(f"Valid dataset length {len(valid_dataset)}")

    # setup optimizer (fixed Adam learning rate; opt.lr is not used here)
    optimizer = optim.Adam(filtered_parameters, lr=0.0001)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter
    epoch = 0
    train_iter_loader = iter(train_loader)

    while True:
        # train part
        try:
            # Use the builtin next(): DataLoader iterators in recent PyTorch
            # releases no longer provide a `.next()` method.
            image_tensors, labels = next(train_iter_loader)
            print("{:4}".format(iteration), end='\r')
        except StopIteration:
            # The loader is exhausted: count the epoch and start a new pass.
            epoch += 1
            print(f"epoch:{epoch}")
            train_iter_loader = iter(train_loader)
            continue
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * preds.size(0))
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                try:
                    cost = criterion(preds, text, preds_size, length)
                except Exception:
                    # Print the offending shapes, then re-raise the original
                    # error (raising a plain string is itself a TypeError and
                    # would hide the real failure).
                    print(preds.shape, preds_size.shape)
                    raise

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
                model.eval()
                print("validation")
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset); ties overwrite
                # the previously saved weights because the comparison is >=.
                if current_accuracy >= best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED >= best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # prune everything from the end-of-sentence token on
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+4 iter.
        if (iteration + 1) % 1e+4 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')
        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
def train(opt):
    """Train the recognition model with batch-balanced data and TensorBoard logs.

    Builds the batch-balanced training dataset and the validation loader,
    creates the label converter, model and loss, then trains until
    ``opt.num_iter`` iterations have been performed. Validation runs at
    iteration 0 and every ``opt.valInterval`` iterations. Scalars are written
    through a ``SummaryWriter`` and text logs go under
    ``./saved_models/{opt.exp_name}/``.

    Args:
        opt: Option namespace. Attributes read here: ``data_filtering_off``,
            ``select_data``, ``batch_ratio``, ``exp_name``, ``imgH``,
            ``imgW``, ``PAD``, ``valid_data``, ``batch_size``, ``workers``,
            ``Prediction``, ``baiduCTC``, ``character``, ``rgb``,
            ``num_fiducial``, ``input_channel``, ``output_channel``,
            ``hidden_size``, ``batch_max_length``, ``Transformation``,
            ``FeatureExtraction``, ``SequenceModeling``, ``saved_model``,
            ``FT``, ``continue_train``, ``model_name``, ``grad_clip``,
            ``valInterval`` and ``num_iter``. ``opt.select_data`` and
            ``opt.batch_ratio`` are replaced by their ``'-'``-split lists,
            ``opt.num_class`` is set from the converter, and
            ``opt.input_channel`` is forced to 3 when ``opt.rgb`` is truthy.

    Raises:
        Exception: If ``opt.saved_model`` is non-empty but neither ``opt.FT``
            nor ``opt.continue_train`` is set.

    Note:
        This function does not return: it calls ``sys.exit()`` once
        ``opt.num_iter`` iterations are done. Parameter updates and
        checkpointing are delegated to the model's own
        ``optimize_parameters`` / ``save_checkpoints`` methods.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    # train_dataset (image, label)
    train_dataset = Batch_Balanced_Dataset(opt)
    # The context manager closes the log even if building the validation
    # dataset or loader raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=
            True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)

    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    print("Model:")
    print(model)

    total_num, true_grad_num, false_grad_num = calculate_model_params(model)
    print("Total parameters: ", total_num)
    print("Number of parameters requires grad: ", true_grad_num)
    print("Number of parameters do not require grad: ", false_grad_num)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # Fall back to filling the weight with ones when the initializer
            # above rejects the parameter.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    # NOTE(review): the DataParallel wrapper is unwrapped again right here, so
    # every later call goes to the bare module (which has been moved to
    # `device`). Presumably this is needed to reach the model's custom methods
    # (load_checkpoint, optimize_parameters, save_checkpoints) -- confirm.
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    # load pretrained model
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_pretrained_networks()
        elif opt.continue_train:
            model.load_checkpoint(opt.model_name)
        else:
            raise Exception('Something went wrong!')

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()
    log_dir = f'./saved_models/{opt.exp_name}'
    writer = SummaryWriter(log_dir)

    # ---- final options ----
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        # The parameter update is performed by the model itself.
        model.optimize_parameters()
        writer.add_scalar('train_loss', cost, iteration + 1)
        loss_avg.add(cost)

        # validation part
        if (
                iteration + 1
        ) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt,
                        iteration)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                writer.add_scalar('val_loss', valid_loss, iteration + 1)
                writer.add_scalar('accuracy', current_accuracy, iteration + 1)
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset); only a strict
                # improvement triggers a new checkpoint.
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    model.save_checkpoints(iteration, 'best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    model.save_checkpoints(iteration, 'best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # prune everything from the end-of-sentence token on
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]
                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            model.save_checkpoints(iteration + 1, opt.model_name)

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            # Flush and close the event file before the process exits so the
            # last buffered scalars are not lost.
            writer.close()
            sys.exit()
        iteration += 1