Exemple #1
0
def train(opt, log):
    """Train a text recognizer, validate periodically, then run the benchmark.

    The function builds the labeled (and optionally unlabeled) training
    loaders, the validation loader, the model, the loss, the optimizer and
    (optionally) a one-cycle scheduler, then runs the training loop from
    ``start_iter + 1`` up to ``opt.num_iter``.  After training it reloads the
    best checkpoint and evaluates it with ``benchmark_all_eval``.

    Args:
        opt: Option namespace.  It is read throughout and also mutated:
            ``batch_ratio`` (only for ``select_data == "synth_SA"``),
            ``sos_token_index`` / ``eos_token_index`` (attention prediction
            only), ``num_class`` and ``eval_type`` are assigned here.
        log: Writable, already-open log object.  Dataset/model/optimizer
            summaries are written to it and it is closed by this function
            before the training loop starts.

    Raises:
        ValueError: If ``opt.optimizer`` is not ``"sgd"``, ``"adadelta"`` or
            ``"adam"``.

    Side effects:
        Appends to ``./saved_models/{opt.exp_name}/log_train.txt``, saves
        ``./saved_models/{opt.exp_name}/best_score.pth``, creates
        ``./result/{opt.exp_name}`` and ``./evaluation_log``, and writes
        scalars through ``opt.writer``.
    """
    # ------------------------------------------------------------------
    # dataset preparation
    # ------------------------------------------------------------------
    # Named presets for convenience; anything else is a "-"-separated list.
    real_data = [
        "1.SVT",
        "2.IIIT",
        "3.IC13",
        "4.IC15",
        "5.COCO",
        "6.RCTW17",
        "7.Uber",
        "8.ArT",
        "9.LSVT",
        "10.MLT19",
        "11.ReCTS",
    ]
    presets = {
        "label": real_data,
        "synth": ["MJ", "ST"],
        "synth_SA": ["MJ", "ST", "SA"],
        "mix": real_data + ["MJ", "ST"],
        "mix_SA": real_data + ["MJ", "ST", "SA"],
    }
    if opt.select_data in presets:
        select_data = list(presets[opt.select_data])
    else:
        select_data = opt.select_data.split("-")

    if opt.select_data == "synth_SA":
        opt.batch_ratio = "0.4-0.4-0.2"  # same ratio with SCATTER paper.

    # set batch_ratio for each data (uniform split when none is given).
    if opt.batch_ratio:
        batch_ratio = opt.batch_ratio.split("-")
    else:
        batch_ratio = [round(1 / len(select_data), 3)] * len(select_data)

    train_loader = Batch_Balanced_Dataset(opt, opt.train_data, select_data,
                                          batch_ratio, log)

    if opt.semi != "None":
        select_data_unlabel = ["U1.Book32", "U2.TextVQA", "U3.STVQA"]
        batch_ratio_unlabel = [round(1 / len(select_data_unlabel), 3)
                               ] * len(select_data_unlabel)
        dataset_root_unlabel = "data_CVPR2021/training/unlabel/"
        train_loader_unlabel_semi = Batch_Balanced_Dataset(
            opt,
            dataset_root_unlabel,
            select_data_unlabel,
            batch_ratio_unlabel,
            log,
            learn_type="semi",
        )

    AlignCollate_valid = AlignCollate(opt, mode="test")
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt, mode="test")
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=False,
    )
    log.write(valid_dataset_log)
    print("-" * 80)
    log.write("-" * 80 + "\n")

    # ------------------------------------------------------------------
    # model configuration
    # ------------------------------------------------------------------
    if "CTC" in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
        opt.sos_token_index = converter.dict["[SOS]"]
        opt.eos_token_index = converter.dict["[EOS]"]
    opt.num_class = len(converter.character)

    model = Model(opt)

    # weight initialization
    for name, param in model.named_parameters():
        if "localization_fc2" in name:
            print(f"Skip {name} as it is already initialized")
            continue
        try:
            if "bias" in name:
                init.constant_(param, 0.0)
            elif "weight" in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if "weight" in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != "":
        fine_tuning_log = f"### loading pretrained model from {opt.saved_model}\n"

        if "MoCo" in opt.saved_model or "MoCo" in opt.self_pre:
            # Keep only the query-encoder weights and strip their prefix so
            # the names can be matched against this model's parameters.
            pretrained_state_dict_qk = torch.load(opt.saved_model)
            pretrained_state_dict = {
                name.replace("encoder_q.", ""): weights
                for name, weights in pretrained_state_dict_qk.items()
                if "encoder_q" in name
            }
        else:
            pretrained_state_dict = torch.load(opt.saved_model)

        for name, param in model.named_parameters():
            try:
                # load from pretrained model
                param.data.copy_(pretrained_state_dict[name].data)
            except (KeyError, RuntimeError):
                # Missing key or incompatible shape: leave the freshly
                # initialized weights in place.
                fine_tuning_log += f"non-pretrained layer: {name}\n"
                continue

            if opt.FT == "freeze":
                param.requires_grad = False  # Freeze
                fine_tuning_log += f"pretrained layer (freezed): {name}\n"
            else:
                fine_tuning_log += f"pretrained layer: {name}\n"

        print(fine_tuning_log)
        log.write(fine_tuning_log + "\n")

    log.write(repr(model) + "\n")

    # ------------------------------------------------------------------
    # setup loss
    # ------------------------------------------------------------------
    if "CTC" in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # ignore [PAD] token
        criterion = torch.nn.CrossEntropyLoss(
            ignore_index=converter.dict["[PAD]"]).to(device)

    if "Pseudo" in opt.semi:
        criterion_SemiSL = PseudoLabelLoss(opt, converter, criterion)
    elif "MeanT" in opt.semi:
        criterion_SemiSL = MeanTeacherLoss(opt, student_for_init_teacher=model)

    # loss averager
    train_loss_avg = Averager()
    semi_loss_avg = Averager()  # semi supervised loss avg

    # keep only the parameters that require gradient descent
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print(f"Trainable params num: {sum(params_num)}")
    log.write(f"Trainable params num: {sum(params_num)}\n")

    # setup optimizer
    if opt.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            filtered_parameters,
            lr=opt.lr,
            momentum=opt.sgd_momentum,
            weight_decay=opt.sgd_weight_decay,
        )
    elif opt.optimizer == "adadelta":
        optimizer = torch.optim.Adadelta(filtered_parameters,
                                         lr=opt.lr,
                                         rho=opt.rho,
                                         eps=opt.eps)
    elif opt.optimizer == "adam":
        optimizer = torch.optim.Adam(filtered_parameters, lr=opt.lr)
    else:
        # Fail here with a clear message instead of a NameError further down.
        raise ValueError(
            f"Unsupported optimizer {opt.optimizer!r}; "
            "expected 'sgd', 'adadelta' or 'adam'")
    print("Optimizer:")
    print(optimizer)
    log.write(repr(optimizer) + "\n")

    if "super" in opt.schedule:
        # Momentum cycling is only enabled together with SGD.
        cycle_momentum = opt.optimizer == "sgd"

        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=opt.lr,
            cycle_momentum=cycle_momentum,
            div_factor=20,
            final_div_factor=1000,
            total_steps=opt.num_iter,
        )
        print("Scheduler:")
        print(scheduler)
        log.write(repr(scheduler) + "\n")

    # ------------------------------------------------------------------
    # final options
    # ------------------------------------------------------------------
    opt_log = "------------ Options -------------\n"
    args = vars(opt)
    for k, v in args.items():
        if str(k) == "character" and len(str(v)) > 500:
            opt_log += f"{str(k)}: So many characters to show all: number of characters: {len(str(v))}\n"
        else:
            opt_log += f"{str(k)}: {str(v)}\n"
    opt_log += "---------------------------------------\n"
    print(opt_log)
    log.write(opt_log)
    log.close()

    # ------------------------------------------------------------------
    # start training
    # ------------------------------------------------------------------
    start_iter = 0
    if opt.saved_model != "":
        try:
            start_iter = int(opt.saved_model.split("_")[-1].split(".")[0])
            print(f"continue to train, start_iter: {start_iter}")
        except ValueError:
            # The checkpoint name does not end in an iteration number;
            # start counting from zero.
            pass

    start_time = time.time()
    best_score = -1

    # Defined up front so the post-training logging below still has a step
    # value when the loop body never runs (start_iter >= opt.num_iter).
    iteration = start_iter

    # training loop
    for iteration in tqdm(
            range(start_iter + 1, opt.num_iter + 1),
            total=opt.num_iter,
            initial=start_iter,  # keep the bar consistent when resuming
            position=0,
            leave=True,
    ):
        if "MeanT" in opt.semi:
            image_tensors, image_tensors_ema, labels = train_loader.get_batch_ema(
            )
        else:
            image_tensors, labels = train_loader.get_batch()

        image = image_tensors.to(device)
        labels_index, labels_length = converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        # default recognition loss part
        if "CTC" in opt.Prediction:
            preds = model(image)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds_log_softmax = preds.log_softmax(2).permute(1, 0, 2)
            loss = criterion(preds_log_softmax, labels_index, preds_size,
                             labels_length)
        else:
            preds = model(image,
                          labels_index[:, :-1])  # align with Attention.forward
            target = labels_index[:, 1:]  # without [SOS] Symbol
            loss = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        # semi supervised part (SemiSL)
        if "Pseudo" in opt.semi:
            image_unlabel, _ = train_loader_unlabel_semi.get_batch_two_images()
            image_unlabel = image_unlabel.to(device)
            loss_SemiSL = criterion_SemiSL(image_unlabel, model)

            loss = loss + loss_SemiSL
            semi_loss_avg.add(loss_SemiSL)

        elif "MeanT" in opt.semi:
            (
                image_tensors_unlabel,
                image_tensors_unlabel_ema,
            ) = train_loader_unlabel_semi.get_batch_two_images()
            image_unlabel = image_tensors_unlabel.to(device)
            student_input = torch.cat([image, image_unlabel], dim=0)

            image_ema = image_tensors_ema.to(device)
            image_unlabel_ema = image_tensors_unlabel_ema.to(device)
            teacher_input = torch.cat([image_ema, image_unlabel_ema], dim=0)
            loss_SemiSL = criterion_SemiSL(
                student_input=student_input,
                student_logit=preds,
                student=model,
                teacher_input=teacher_input,
                iteration=iteration,
            )

            loss = loss + loss_SemiSL
            semi_loss_avg.add(loss_SemiSL)

        model.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        train_loss_avg.add(loss)

        if "super" in opt.schedule:
            scheduler.step()
        else:
            adjust_learning_rate(optimizer, iteration, opt)

        # validation part.
        # To see training progress, we also conduct validation when 'iteration == 1'
        if iteration % opt.val_interval == 0 or iteration == 1:
            # for validation log
            with open(f"./saved_models/{opt.exp_name}/log_train.txt",
                      "a") as train_log:
                model.eval()
                with torch.no_grad():
                    (
                        valid_loss,
                        current_score,
                        preds,
                        confidence_score,
                        labels,
                        infer_time,
                        length_of_data,
                    ) = validation(model, criterion, valid_loader, converter,
                                   opt)
                model.train()

                # keep best score (accuracy or norm ED) model on valid dataset
                # Do not use this on test datasets. It would be an unfair comparison
                # (training should be done without referring test set).
                if current_score > best_score:
                    best_score = current_score
                    torch.save(
                        model.state_dict(),
                        f"./saved_models/{opt.exp_name}/best_score.pth",
                    )

                # validation log: loss, lr, score (accuracy or norm ED), time.
                lr = optimizer.param_groups[0]["lr"]
                elapsed_time = time.time() - start_time
                valid_log = f"\n[{iteration}/{opt.num_iter}] Train_loss: {train_loss_avg.val():0.5f}, Valid_loss: {valid_loss:0.5f}"
                valid_log += f", Semi_loss: {semi_loss_avg.val():0.5f}\n"
                valid_log += f'{"Current_score":17s}: {current_score:0.2f}, Current_lr: {lr:0.7f}\n'
                valid_log += f'{"Best_score":17s}: {best_score:0.2f}, Infer_time: {infer_time:0.1f}, Elapsed_time: {elapsed_time:0.1f}'

                # show some predicted results
                dashed_line = "-" * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n"
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if "Attn" in opt.Prediction:
                        # cut both strings at the end-of-sequence marker
                        gt = gt[:gt.find("[EOS]")]
                        pred = pred[:pred.find("[EOS]")]

                    predicted_result_log += f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n"
                predicted_result_log += f"{dashed_line}"
                valid_log = f"{valid_log}\n{predicted_result_log}"
                print(valid_log)
                train_log.write(valid_log + "\n")

                opt.writer.add_scalar("train/train_loss",
                                      float(f"{train_loss_avg.val():0.5f}"),
                                      iteration)
                opt.writer.add_scalar("train/semi_loss",
                                      float(f"{semi_loss_avg.val():0.5f}"),
                                      iteration)
                opt.writer.add_scalar("train/lr", float(f"{lr:0.7f}"),
                                      iteration)
                opt.writer.add_scalar("train/elapsed_time",
                                      float(f"{elapsed_time:0.1f}"), iteration)
                opt.writer.add_scalar("valid/valid_loss",
                                      float(f"{valid_loss:0.5f}"), iteration)
                opt.writer.add_scalar("valid/current_score",
                                      float(f"{current_score:0.2f}"),
                                      iteration)
                opt.writer.add_scalar("valid/best_score",
                                      float(f"{best_score:0.2f}"), iteration)

                train_loss_avg.reset()
                semi_loss_avg.reset()

    # ------------------------------------------------------------------
    # evaluation at the end of training
    # ------------------------------------------------------------------
    print("Start evaluation on benchmark testset")
    # keep evaluation model and result logs
    os.makedirs(f"./result/{opt.exp_name}", exist_ok=True)
    os.makedirs("./evaluation_log", exist_ok=True)
    saved_best_model = f"./saved_models/{opt.exp_name}/best_score.pth"
    model.load_state_dict(torch.load(f"{saved_best_model}"))

    opt.eval_type = "benchmark"
    model.eval()
    with torch.no_grad():
        total_accuracy, eval_data_list, accuracy_list = benchmark_all_eval(
            model, criterion, converter, opt)

    opt.writer.add_scalar("test/total_accuracy",
                          float(f"{total_accuracy:0.2f}"), iteration)
    for eval_data, accuracy in zip(eval_data_list, accuracy_list):
        accuracy = float(accuracy)
        opt.writer.add_scalar(f"test/{eval_data}", float(f"{accuracy:0.2f}"),
                              iteration)

    print(
        f'finished the experiment: {opt.exp_name}, "CUDA_VISIBLE_DEVICES" was {opt.CUDA_VISIBLE_DEVICES}'
    )
Exemple #2
0
def test(opt):
    """Write EigenCAM visualizations for the evaluation data of a saved model.

    Loads the weights from ``opt.saved_model``, iterates over the dataset at
    ``opt.eval_data`` and, for the first image of every batch, stacks the
    source image on top of its CAM overlay (both upscaled 5x) and writes the
    result to ``visual_image2/{batch_index}_2.bmp``.

    Args:
        opt: Option namespace.  ``num_class``, ``exp_name`` and (when
            ``opt.rgb`` is set) ``input_channel`` are assigned here.
    """
    import os  # local import: only needed to create the output directory

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))
    opt.exp_name = '_'.join(opt.saved_model.split('/')[1:])

    # ---- evaluation data ----
    AlignCollate_evaluation = AlignCollate(imgH=opt.imgH,
                                           imgW=opt.imgW,
                                           keep_ratio_with_pad=opt.PAD)
    eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data,
                                                    opt=opt)
    evaluation_loader = torch.utils.data.DataLoader(
        eval_data,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_evaluation,
        pin_memory=True)

    # Construct the CAM object once, and then re-use it on many images.
    # Both the target layer and the CAM wrapper are loop-invariant.
    target_layer = model.module.FeatureExtraction.ConvNet.layer4[-1]
    cam = EigenCAM(model=model,
                   target_layer=target_layer,
                   use_cuda=opt.use_cuda)

    to_pil = transforms.ToPILImage()

    # cv2.imwrite does not create missing directories, so make sure the
    # output folder exists before the first write.
    os.makedirs('visual_image2', exist_ok=True)

    for i, (image_tensors, labels) in enumerate(evaluation_loader):
        # Note: input_tensor can be a batch tensor with several images!
        input_tensor = image_tensors

        # The encoded ground-truth text is used as the CAM target; the
        # accompanying length tensor is not needed here.
        text_for_loss, _ = converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        target_category = text_for_loss

        # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
        grayscale_cam = cam(input_tensor=input_tensor,
                            target_category=target_category)

        # Only the first image of the batch is visualized.
        grayscale_cam = grayscale_cam[0, :]
        sourc_image = to_pil(image_tensors[0])
        sourc_image = cv2.cvtColor(np.asarray(sourc_image), cv2.COLOR_RGB2BGR)

        # show_cam_on_image is fed the image scaled to [0, 1].
        rgb_img2 = np.float32(sourc_image) / 255
        visualization = show_cam_on_image(rgb_img2, grayscale_cam)

        # Upscale both pictures 5x and stack source above overlay.
        sourc_image = cv2.resize(sourc_image, (0, 0),
                                 fx=5,
                                 fy=5,
                                 interpolation=cv2.INTER_CUBIC)
        visualization = cv2.resize(visualization, (0, 0),
                                   fx=5,
                                   fy=5,
                                   interpolation=cv2.INTER_CUBIC)
        cat_image = np.vstack((sourc_image, visualization))
        cv2.imwrite('visual_image2/' + str(i) + '_2.bmp', cat_image)
def train(opt):
    """Train a text recognizer with periodic validation and checkpointing.

    Builds the balanced training dataset, the validation loader, the model,
    loss and optimizer, then trains until iteration ``opt.num_iter`` is
    reached, at which point the process is terminated with ``sys.exit()``.

    Args:
        opt: Option namespace.  ``select_data`` and ``batch_ratio`` are
            replaced by their "-"-split lists, and ``num_class`` and (when
            ``opt.rgb`` is set) ``input_channel`` are assigned here.

    Side effects:
        Appends to ``opt.txt`` and ``log_train.txt`` and saves
        ``best_accuracy.pth``, ``best_norm_ED.pth`` and ``iter_{N}.pth``
        checkpoints under ``./saved_models/{opt.experiment_name}/``.
    """
    # ---- dataset preparation ----
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    print('-' * 80)

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            # fine-tuning: tolerate missing/unexpected keys
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # The checkpoint name does not end in an iteration number
            # (e.g. "best_accuracy.pth"); start counting from zero instead
            # of crashing.
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)

            # cuDNN is switched off only for the CTC loss call; the
            # try/finally guarantees it is switched back on even if the
            # loss computation raises.
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            finally:
                torch.backends.cudnn.enabled = True

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                print(loss_log)
                log.write(loss_log + '\n')
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'
                print(current_model_log)
                log.write(current_model_log + '\n')

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth')
                # NOTE(review): a lower norm_ED is treated as better here —
                # confirm this matches what validation() returns.
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

                # show some predicted results
                print('-' * 80)
                print(f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F')
                log.write(f'{"Ground Truth":25s} | {"Prediction":25s} | {"Confidence Score"}\n')
                print('-' * 80)
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # cut both strings at the end-of-sequence marker
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    print(f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}')
                    log.write(f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n')
                print('-' * 80)

        # save model per 1e+5 iter.
        if (i + 1) % 100000 == 0:
            torch.save(
                model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        # ">=" rather than "==" so that resuming from a checkpoint whose
        # iteration is already past opt.num_iter still terminates.
        if i >= opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
Exemple #4
0
def train(opt, show_number=2, amp=False):
    """Train the recognition model configured by ``opt``, validating periodically.

    The function builds the training/validation data pipelines, the model, the
    loss and the optimizer, then loops forever over training batches.  Every
    ``opt.valInterval`` iterations (except iteration 0) it runs validation,
    logs the results to ``./saved_models/<experiment_name>/log_train.txt`` and
    keeps the best-accuracy / best-norm_ED weights.  A checkpoint is also
    written every 10,000 iterations.

    Args:
        opt: Options namespace.  It is mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set from the
            label converter and ``input_channel`` is forced to 3 when
            ``opt.rgb`` is truthy.
        show_number: How many validation samples to print/log after each
            validation round.  Clamped to the number of samples available.
        amp: When true, the forward pass runs under ``autocast`` and the
            optimizer step goes through a ``GradScaler``.

    Note:
        This function does not return; it calls ``sys.exit()`` once the
        iteration counter equals ``opt.num_iter``.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD,
                                      contrast_adjust=opt.contrast_adjust)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=min(32, opt.batch_size),
        shuffle=
        True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        prefetch_factor=512,
        collate_fn=AlignCollate_valid,
        pin_memory=True)

    # The context manager guarantees the dataset log is closed even on error.
    with open(f'./saved_models/{opt.experiment_name}/log_dataset.txt',
              'a',
              encoding="utf8") as log:
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    if opt.saved_model != '':
        pretrained_dict = torch.load(opt.saved_model)
        if opt.new_prediction:
            # Temporarily size the head like the checkpoint's so that the
            # state dict can be loaded; it is replaced again below.
            model.Prediction = nn.Linear(
                model.SequenceModeling_output,
                len(pretrained_dict['module.Prediction.weight']))

        model = torch.nn.DataParallel(model).to(device)
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(pretrained_dict, strict=False)
        else:
            model.load_state_dict(pretrained_dict)
        if opt.new_prediction:
            # Fresh prediction head sized for the current character set.
            model.module.Prediction = nn.Linear(
                model.module.SequenceModeling_output, opt.num_class)
            for name, param in model.module.Prediction.named_parameters():
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            model = model.to(device)
    else:
        # weight initialization
        for name, param in model.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception:  # for batchnorm.
                if 'weight' in name:
                    param.data.fill_(1)
                continue
        model = torch.nn.DataParallel(model).to(device)

    model.train()
    print("Model:")
    print(model)
    count_parameters(model)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # Freeze some layers.  This stays best-effort: a missing option or a model
    # without the corresponding sub-module is reported and skipped, but any
    # other error is no longer silently swallowed.
    try:
        if opt.freeze_FeatureFxtraction:
            for param in model.module.FeatureExtraction.parameters():
                param.requires_grad = False
        if opt.freeze_SequenceModeling:
            for param in model.module.SequenceModeling.parameters():
                param.requires_grad = False
    except AttributeError as err:
        print(f'Skipping layer freezing: {err}')

    # keep only the parameters that require gradients
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.optim == 'adam':
        # NOTE: Adam runs with the library-default hyper-parameters here;
        # opt.lr is only applied on the Adadelta path.
        optimizer = optim.Adam(filtered_parameters)
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.experiment_name}/opt.txt',
              'a',
              encoding="utf8") as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        # Resume the iteration counter from a checkpoint named "..._<N>.pth";
        # any other file name simply restarts counting from 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except (AttributeError, ValueError):
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    scaler = GradScaler()
    t1 = time.time()

    while True:
        # train part
        optimizer.zero_grad(set_to_none=True)

        if amp:
            with autocast():
                cost = _compute_train_cost(opt, model, criterion, converter,
                                           train_dataset)
            scaler.scale(cost).backward()
            # Unscale first so that clipping sees the true gradient norms.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            cost = _compute_train_cost(opt, model, criterion, converter,
                                       train_dataset)
            cost.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
            optimizer.step()
        loss_avg.add(cost)

        # validation part
        if (i % opt.valInterval == 0) and (i != 0):
            print('training time: ', time.time() - t1)
            t1 = time.time()
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt',
                      'a',
                      encoding="utf8") as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels,\
                    infer_time, length_of_data = validation(model, criterion, valid_loader, converter, opt, device)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.4f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                    )
                # In this function a larger norm_ED counts as better.
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                    )
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.4f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                # Clamp the window so that randint never receives an empty
                # range when validation returned fewer samples than requested.
                n_show = min(show_number, len(labels))
                start = random.randint(0, len(labels) - n_show)
                for gt, pred, confidence in zip(
                        labels[start:start + n_show],
                        preds[start:start + n_show],
                        confidence_score[start:start + n_show]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')
                print('validation time: ', time.time() - t1)
                t1 = time.time()
        # save model per 1e+4 iter.
        if (i + 1) % 1e+4 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1


def _compute_train_cost(opt, model, criterion, converter, train_dataset):
    """Fetch one training batch, run the forward pass and return its loss."""
    image_tensors, labels = train_dataset.get_batch()
    image = image_tensors.to(device)
    text, length = converter.encode(labels,
                                    batch_max_length=opt.batch_max_length)
    batch_size = image.size(0)

    if 'CTC' in opt.Prediction:
        preds = model(image, text).log_softmax(2)
        preds_size = torch.IntTensor([preds.size(1)] * batch_size)
        preds = preds.permute(1, 0, 2)  # reorder axes for CTCLoss
        # cuDNN is switched off only around the CTC loss call; the finally
        # clause re-enables it even if the loss computation raises.
        torch.backends.cudnn.enabled = False
        try:
            return criterion(preds, text.to(device), preds_size.to(device),
                             length.to(device))
        finally:
            torch.backends.cudnn.enabled = True

    preds = model(image, text[:, :-1])  # align with Attention.forward
    target = text[:, 1:]  # without [GO] Symbol
    return criterion(preds.view(-1, preds.shape[-1]),
                     target.contiguous().view(-1))
def train(opt):
    """Run distributed training of the recognition model configured by ``opt``.

    Supports three prediction heads selected by ``opt.Prediction``: ``'CTC'``,
    ``'Attn'`` and ``'CTC_Attn'`` (joint CTC + attention).  Gradients are
    accumulated over ``opt.batch_mul`` batches before every optimizer step, and
    the iteration counter only advances on optimizer steps.  Validation runs
    every ``opt.valInterval`` iterations (and at iteration 0 when resuming from
    ``opt.continue_model``); rank 0 saves the best and periodic checkpoints
    under ``./saved_models/<experiment_name>/``.

    Args:
        opt: Options namespace.  Mutated in place: ``device``, ``num_class``
            and, when ``opt.rgb`` is truthy, ``input_channel`` are set.

    Raises:
        ValueError: If ``opt.Prediction`` is not one of the supported heads.

    Note:
        This function does not return; it calls ``sys.exit()`` when the
        iteration counter equals ``opt.num_iter`` or exceeds ``opt.prof_iter``
        (when that is positive).
    """
    print(opt.local_rank)
    opt.device = torch.device('cuda:{}'.format(opt.local_rank))
    device = opt.device

    # ---- dataset preparation ----
    train_dataset = Batch_Balanced_Dataset(opt)

    valid_loader = train_dataset.getValDataloader()
    print('-' * 80)

    # ---- model configuration ----
    if 'CTC' == opt.Prediction:
        converter = CTCLabelConverter(opt.character, opt)
    elif 'Attn' == opt.Prediction:
        converter = AttnLabelConverter(opt.character, opt)
    elif 'CTC_Attn' == opt.Prediction:
        converter = CTCLabelConverter(opt.character, opt), AttnLabelConverter(
            opt.character, opt)
    else:
        # Fail early: otherwise the loop below dies with an opaque NameError.
        raise ValueError(
            f'Unsupported opt.Prediction: {opt.Prediction!r} '
            "(expected 'CTC', 'Attn' or 'CTC_Attn')")

    opt.num_class = len(opt.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)

    model.to(opt.device)

    print(model)
    print('model input parameters', opt.rgb, opt.imgH, opt.imgW,
          opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling,
          opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # ---- setup loss ----
    # For 'Attn' and 'CTC_Attn' the criterion is a tuple whose last element is
    # the MSE loss handed to alpha_loss for the position regularizer.
    if 'CTC' == opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    elif 'Attn' == opt.Prediction:
        criterion = torch.nn.CrossEntropyLoss(
            ignore_index=0).to(device), torch.nn.MSELoss(
                reduction="sum").to(device)
        # ignore [GO] token = ignore index 0
    elif 'CTC_Attn' == opt.Prediction:
        criterion = torch.nn.CTCLoss(
            zero_infinity=True).to(device), torch.nn.CrossEntropyLoss(
                ignore_index=0).to(device), torch.nn.MSELoss(
                    reduction='sum').to(device)

    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]

    if opt.local_rank == 0:
        print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.sgd:
        optimizer = optim.SGD(filtered_parameters,
                              lr=opt.lr,
                              momentum=0.9,
                              weight_decay=opt.weight_decay)
    elif opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    if opt.local_rank == 0:
        print("Optimizer:")
        print(optimizer)

    if opt.sync_bn:
        model = apex.parallel.convert_syncbn_model(model)

    # opt.amp selects the amp optimization level ("O0", "O1", ...).
    if opt.amp > 1:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O" + str(opt.amp),
                                          keep_batchnorm_fp32=True,
                                          loss_scale="dynamic")
    else:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O" + str(opt.amp))

    # data parallel for multi-GPU
    model = DDP(model)

    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        try:
            model.load_state_dict(
                torch.load(opt.continue_model,
                           map_location=torch.device(
                               'cuda', torch.cuda.current_device())))
        except Exception:
            # Strict loading failed: report why, then fall back to copying only
            # the tensors whose name and shape match the current model.
            traceback.print_exc()
            print(f'COPYING pretrained model from {opt.continue_model}')
            pretrained_dict = torch.load(opt.continue_model,
                                         map_location=torch.device(
                                             'cuda',
                                             torch.cuda.current_device()))

            model_dict = model.state_dict()
            pretrained_dict2 = dict()
            for k, v in pretrained_dict.items():
                if opt.Prediction == 'Attn':
                    # Map the attention head of a checkpoint that stores it as
                    # "Prediction_attn" onto this model's "Prediction" module.
                    if 'module.Prediction_attn.' in k:
                        k = k.replace('module.Prediction_attn.',
                                      'module.Prediction.')
                if k in model_dict and model_dict[k].shape == v.shape:
                    pretrained_dict2[k] = v

            model_dict.update(pretrained_dict2)
            model.load_state_dict(model_dict)

    model.train()

    # ---- final options ----
    # Explicit UTF-8: the dumped options include the (possibly non-ASCII)
    # character set.
    with open(f'./saved_models/{opt.experiment_name}/opt.txt',
              'a',
              encoding="utf8") as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        opt_log += str(model)
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6  # in this function a smaller norm_ED counts as better
    i = start_iter

    ct = opt.batch_mul  # batches left before the next optimizer step
    model.zero_grad()

    dist.barrier()
    while True:
        # train part
        start = time.time()
        image, labels, pos = train_dataset.sync_get_batch()
        end = time.time()
        data_t = end - start

        start = time.time()
        batch_size = image.size(0)

        if 'CTC' == opt.Prediction:
            text, length = converter.encode(
                labels, batch_max_length=opt.batch_max_length)
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] *
                                         batch_size).to(device)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format

            # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text, preds_size, length)
            torch.backends.cudnn.enabled = True
        elif 'Attn' == opt.Prediction:
            text, length = converter.encode(
                labels, batch_max_length=opt.batch_max_length)
            preds = model(image, text[:, :-1])  # align with Attention.forward
            preds_attn = preds[0]
            preds_alpha = preds[1]
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion[0](preds_attn.view(-1, preds_attn.shape[-1]),
                                target.contiguous().view(-1))

            if opt.posreg_w > 0.001:
                cost_pos = alpha_loss(preds_alpha, pos, opt, criterion[1])
                print('attn_cost = ', cost, 'pos_cost = ',
                      cost_pos * opt.posreg_w)
                cost += opt.posreg_w * cost_pos
            else:
                # In this branch the attention loss lives in `cost`; the name
                # `cost_attn` is only ever bound on the 'CTC_Attn' path.
                print('attn_cost = ', cost)

        elif 'CTC_Attn' == opt.Prediction:
            text_ctc, length_ctc = converter[0].encode(
                labels, batch_max_length=opt.batch_max_length)
            text_attn, length_attn = converter[1].encode(
                labels, batch_max_length=opt.batch_max_length)

            # -- ctc prediction and loss --
            # should input text_attn here
            preds = model(image, text_attn[:, :-1])
            preds_ctc = preds[0].log_softmax(2)

            preds_ctc_size = torch.IntTensor([preds_ctc.size(1)] *
                                             batch_size).to(device)
            preds_ctc = preds_ctc.permute(1, 0, 2)  # to use CTCLoss format

            # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16

            torch.backends.cudnn.enabled = False
            cost_ctc = criterion[0](preds_ctc, text_ctc, preds_ctc_size,
                                    length_ctc)
            torch.backends.cudnn.enabled = True

            # -- attention prediction and loss --
            preds_attn = preds[1][0]  # align with Attention.forward
            preds_alpha = preds[1][1]

            target = text_attn[:, 1:]  # without [GO] Symbol

            cost_attn = criterion[1](preds_attn.view(-1, preds_attn.shape[-1]),
                                     target.contiguous().view(-1))

            cost = opt.ctc_attn_loss_ratio * cost_ctc + (
                1 - opt.ctc_attn_loss_ratio) * cost_attn

            if opt.posreg_w > 0.001:
                cost_pos = alpha_loss(preds_alpha, pos, opt, criterion[2])
                cost += opt.posreg_w * cost_pos
                cost_ctc = reduce_tensor(cost_ctc)
                cost_attn = reduce_tensor(cost_attn)
                cost_pos = reduce_tensor(cost_pos)
                if opt.local_rank == 0:
                    print('ctc_cost = ', cost_ctc, 'attn_cost = ', cost_attn,
                          'pos_cost = ', cost_pos * opt.posreg_w)
            else:
                cost_ctc = reduce_tensor(cost_ctc)
                cost_attn = reduce_tensor(cost_attn)
                if opt.local_rank == 0:
                    print('ctc_cost = ', cost_ctc, 'attn_cost = ', cost_attn)

        # Scale the loss so the accumulated gradient matches one large batch.
        cost /= opt.batch_mul
        if opt.amp > 0:
            with amp.scale_loss(cost, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            cost.backward()

        # Gradient accumulation, see
        # https://github.com/davidlmorton/learning-rate-schedules/blob/master/increasing_batch_size_without_increasing_memory.ipynb
        ct -= 1
        if ct == 0:
            if opt.amp > 0:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               opt.grad_clip)
            else:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer.step()
            model.zero_grad()
            ct = opt.batch_mul
        else:
            # Still accumulating: skip logging/validation and do not advance
            # `i`, which counts optimizer steps rather than batches.
            continue

        train_t = time.time() - start
        cost = reduce_tensor(cost)
        loss_avg.add(cost)
        if opt.local_rank == 0:
            print('iter', i, 'loss =', cost, ', data_t=', data_t, ',train_t=',
                  train_t, ', batchsz=', opt.batch_mul * opt.batch_size)
        sys.stdout.flush()
        # validation part
        if (i > 0 and i % opt.valInterval == 0) or (i == 0 and
                                                    opt.continue_model != ''):
            elapsed_time = time.time() - start_time
            print(
                f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
            )
            # for log (explicit UTF-8: labels/predictions may be non-ASCII)
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt',
                      'a',
                      encoding="utf8") as log:
                log.write(
                    f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                )
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    if 'CTC_Attn' in opt.Prediction:
                        # we only count for attention accuracy, because ctc is used to help attention
                        valid_loss, current_accuracy_ctc, current_accuracy, current_norm_ED_ctc, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                            model, criterion[1], valid_loader, converter[1],
                            opt, converter[0])
                    elif 'Attn' in opt.Prediction:
                        valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                            model, criterion[0], valid_loader, converter, opt)
                    else:
                        valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                            model, criterion, valid_loader, converter, opt)
                model.train()

                for pred, gt in zip(preds[:10], labels[:10]):
                    if 'Attn' in opt.Prediction:
                        pred = pred[:pred.find('[s]')]
                        gt = gt[:gt.find('[s]')]
                    print(f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}')
                    log.write(
                        f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}\n')

                valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'

                if 'CTC_Attn' in opt.Prediction:
                    valid_log += f' ctc_accuracy: {current_accuracy_ctc:0.3f}, ctc_norm_ED: {current_norm_ED_ctc:0.2f}'
                    # Track the better of the two heads for model selection.
                    current_accuracy = max(current_accuracy,
                                           current_accuracy_ctc)
                    current_norm_ED = min(current_norm_ED, current_norm_ED_ctc)

                if opt.local_rank == 0:
                    print(valid_log)
                    log.write(valid_log + '\n')

                    # keep best accuracy model
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                        )
                        torch.save(
                            model,
                            f'./saved_models/{opt.experiment_name}/best_accuracy.model'
                        )
                    if current_norm_ED < best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                        )
                        torch.save(
                            model,
                            f'./saved_models/{opt.experiment_name}/best_norm_ED.model'
                        )
                    best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                    print(best_model_log)
                    log.write(best_model_log + '\n')

        # save model per iter.
        if (i + 1) % opt.save_interval == 0 and opt.local_rank == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        # Optional early stop after opt.prof_iter iterations.
        if opt.prof_iter > 0 and i > opt.prof_iter:
            sys.exit()

        i += 1
Exemple #6
0
def _checkpoint_iteration(path):
    """Return the iteration encoded in an ``iter_<N>_<tag>.pth`` path, or -1 if it cannot be parsed."""
    try:
        return int(os.path.basename(path).split('_')[-2].split('.')[0])
    except (ValueError, IndexError):
        return -1


def _trainable_parameters(module):
    """Return the parameters of ``module`` that require gradients and their total element count."""
    parameters = [p for p in module.parameters() if p.requires_grad]
    return parameters, sum(np.prod(p.size()) for p in parameters)


def _build_optimizer(parameters, opt):
    """Build Adam when ``opt.optim == 'adam'``, otherwise Adadelta, using the hyper-parameters in ``opt``."""
    if opt.optim == 'adam':
        return optim.Adam(parameters,
                          lr=opt.lr,
                          betas=(opt.beta1, opt.beta2),
                          weight_decay=opt.weight_decay)
    return optim.Adadelta(parameters,
                          lr=opt.lr,
                          rho=opt.rho,
                          eps=opt.eps,
                          weight_decay=opt.weight_decay)


def train(opt):
    """Jointly train the synthesis generator, the OCR model and the discriminator.

    The function builds the balanced training set and the validation loader,
    instantiates ``AdaINGenV4`` (generator), ``Model`` (OCR) and
    ``MsImageDisV1`` (discriminator), optionally restores their weights, and
    then iterates: one OCR update (skipped when ``opt.ocrFixed``), one
    discriminator update and one generator update per step. On iteration 0 and
    every ``opt.valInterval`` iterations it dumps the current training images,
    runs ``validation_synth_lrw_res``, logs and plots the metrics, and stores
    the best checkpoints under ``opt.exp_dir/opt.exp_name``.

    Args:
        opt: option namespace. It is mutated in place: ``select_data`` and
            ``batch_ratio`` are split into lists, ``batch_size`` is doubled,
            ``num_class`` is set, ``input_channel`` is set to 3 when
            ``opt.rgb``, and ``saved_synth_model`` / ``saved_dis_model`` may be
            replaced by checkpoints found in the experiment folder when
            ``opt.modelFolderFlag`` is set.

    Returns:
        Never returns normally; the process is ended with ``sys.exit()`` once
        ``opt.num_iter`` is reached.
    """
    plotDir = os.path.join(opt.exp_dir, opt.exp_name, 'plots')
    os.makedirs(plotDir, exist_ok=True)

    lib.print_model_settings(locals().copy())

    # ---------------- dataset preparation ----------------
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')

    # Every batch is split in half inside the loop: the first half drives the
    # generator, the second half supplies real images for the discriminator.
    opt.batch_size = opt.batch_size * 2

    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager guarantees the dataset log is closed even when the
    # validation dataset cannot be built.
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'),
              'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=False,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---------------- model configuration ----------------
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    model = AdaINGenV4(opt)
    ocrModel = Model(opt)
    disModel = MsImageDisV1(opt)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for currModel in [model, ocrModel, disModel]:
        for name, param in currModel.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception:  # for batchnorm.
                # kaiming_normal_ rejects tensors with fewer than 2 dims;
                # such weights are set to 1 instead.
                if 'weight' in name:
                    param.data.fill_(1)
                continue

    # data parallel for multi-GPU
    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    if not opt.ocrFixed:
        ocrModel.train()
    else:
        # With a fixed OCR model every sub-module except SequenceModeling is
        # put in eval mode.
        ocrModel.module.Transformation.eval()
        ocrModel.module.FeatureExtraction.eval()
        ocrModel.module.AdaptiveAvgPool.eval()
        ocrModel.module.Prediction.eval()

    model = torch.nn.DataParallel(model).to(device)
    model.train()

    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    if opt.modelFolderFlag:
        # Resume from the newest checkpoints in the experiment folder.
        # glob.glob returns paths in arbitrary order, so the checkpoint is
        # selected by its iteration number, not by its position in the list.
        synth_checkpoints = glob.glob(
            os.path.join(opt.exp_dir, opt.exp_name, "iter_*_synth.pth"))
        if synth_checkpoints:
            opt.saved_synth_model = max(synth_checkpoints,
                                        key=_checkpoint_iteration)

        dis_checkpoints = glob.glob(
            os.path.join(opt.exp_dir, opt.exp_name, "iter_*_dis.pth"))
        if dis_checkpoints:
            opt.saved_dis_model = max(dis_checkpoints,
                                      key=_checkpoint_iteration)

    # loading pre-trained models (opt.FT relaxes the state-dict key matching)
    if opt.saved_ocr_model != '' and opt.saved_ocr_model != 'None':
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        if opt.FT:
            ocrModel.load_state_dict(torch.load(opt.saved_ocr_model),
                                     strict=False)
        else:
            ocrModel.load_state_dict(torch.load(opt.saved_ocr_model))
    print("OCRModel:")
    print(ocrModel)

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_synth_model),
                                  strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_synth_model))
    print("SynthModel:")
    print(model)

    if opt.saved_dis_model != '' and opt.saved_dis_model != 'None':
        print(
            f'loading pretrained discriminator model from {opt.saved_dis_model}'
        )
        if opt.FT:
            disModel.load_state_dict(torch.load(opt.saved_dis_model),
                                     strict=False)
        else:
            disModel.load_state_dict(torch.load(opt.saved_dis_model))
    print("DisModel:")
    print(disModel)

    # ---------------- setup loss ----------------
    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0

    if opt.imgReconLoss == 'l1':
        recCriterion = torch.nn.L1Loss()
    elif opt.imgReconLoss == 'ssim':
        recCriterion = ssim
    elif opt.imgReconLoss == 'ms-ssim':
        recCriterion = msssim

    if opt.styleLoss == 'l1':
        styleRecCriterion = torch.nn.L1Loss()
    elif opt.styleLoss == 'triplet':
        styleRecCriterion = torch.nn.TripletMarginLoss(
            margin=opt.tripletMargin, p=1)
    # for validation; check only positive pairs
    styleTestRecCriterion = torch.nn.L1Loss()

    # loss averagers (reset after every validation round)
    loss_avg_ocr = Averager()
    loss_avg = Averager()
    loss_avg_dis = Averager()

    loss_avg_ocrRecon_1 = Averager()
    loss_avg_ocrRecon_2 = Averager()
    loss_avg_gen = Averager()
    loss_avg_imgRecon = Averager()
    loss_avg_styRecon = Averager()

    # ---------------- optimizers ----------------
    # Only parameters that require gradients are handed to each optimizer.
    filtered_parameters, params_total = _trainable_parameters(model)
    print('Trainable params num : ', params_total)
    optimizer = _build_optimizer(filtered_parameters, opt)
    print("SynthOptimizer:")
    print(optimizer)

    ocr_filtered_parameters, ocr_params_total = _trainable_parameters(ocrModel)
    print('OCR Trainable params num : ', ocr_params_total)
    ocr_optimizer = _build_optimizer(ocr_filtered_parameters, opt)
    print("OCROptimizer:")
    print(ocr_optimizer)

    dis_filtered_parameters, dis_params_total = _trainable_parameters(disModel)
    print('Dis Trainable params num : ', dis_params_total)
    dis_optimizer = _build_optimizer(dis_filtered_parameters, opt)
    print("DisOptimizer:")
    print(dis_optimizer)

    # ---------------- final options ----------------
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'),
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---------------- start training ----------------
    start_iter = 0

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        # Checkpoints are written as iter_<N>_synth.pth; recover <N> so the
        # iteration counter continues from there. Names that do not follow
        # this pattern simply start from 0.
        try:
            start_iter = int(
                opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except (ValueError, IndexError):
            pass

    # get schedulers
    scheduler = get_scheduler(optimizer, opt)
    ocr_scheduler = get_scheduler(ocr_optimizer, opt)
    dis_scheduler = get_scheduler(dis_optimizer, opt)

    start_time = time.time()
    # The "norm_ED" values are compared with '>' below, i.e. higher is better.
    best_accuracy = -1
    best_norm_ED = -1
    best_accuracy_ocr = -1
    best_norm_ED_ocr = -1
    iteration = start_iter

    while True:
        # ---------------- train part ----------------

        # NOTE(review): the schedulers are stepped before the optimizers;
        # PyTorch >= 1.1 expects optimizer.step() first. Left unchanged to keep
        # the existing learning-rate schedule - confirm this is intended.
        if opt.lr_policy != "None":
            scheduler.step()
            ocr_scheduler.step()
            dis_scheduler.step()

        image_tensors_all, labels_1_all, labels_2_all = train_dataset.get_batch(
        )

        # First half of the batch: generator input (with its two label sets);
        # second half: real images for the discriminator.
        disCnt = int(image_tensors_all.size(0) / 2)
        image_tensors = image_tensors_all[:disCnt]
        image_tensors_real = image_tensors_all[disCnt:disCnt + disCnt]
        labels_gt = labels_1_all[:disCnt]
        labels_2 = labels_2_all[:disCnt]

        image_hole_tensors, image_mask_tensors = genRandomMasks(image_tensors)

        image = image_tensors.to(device)
        image_hole_tensors = image_hole_tensors.to(device)
        image_mask_tensors = image_mask_tensors.to(device)
        image_real = image_tensors_real.to(device)
        batch_size = image.size(0)

        # generate text (labels) from ocr.forward when the OCR model is fixed,
        # otherwise use the ground-truth labels
        if opt.ocrFixed:
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                             1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                # The CTC output is decoded over its full time dimension.
                preds = ocrModel(image, text_for_pred)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                labels_1 = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = ocrModel(image, text_for_pred, is_train=False)
                _, preds_index = preds.max(2)
                labels_1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(labels_1):
                    pred_EOS = pred.find('[s]')
                    # prune after "end of sentence" token ([s])
                    labels_1[idx] = pred[:pred_EOS]
        else:
            labels_1 = labels_gt

        text_1, length_1 = converter.encode(
            labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(
            labels_2, batch_max_length=opt.batch_max_length)

        # forward pass from style and word generator (input is the masked image)
        images_recon_1, images_recon_2, style = model(image_hole_tensors,
                                                      text_1, text_2)

        if 'CTC' in opt.Prediction:

            if not opt.ocrFixed:
                # ocr training with orig image
                preds_ocr = ocrModel(image, text_1)
                preds_size_ocr = torch.IntTensor([preds_ocr.size(1)] *
                                                 batch_size)
                preds_ocr = preds_ocr.log_softmax(2).permute(1, 0, 2)

                ocrCost_train = ocrCriterion(preds_ocr, text_1, preds_size_ocr,
                                             length_1)

            # content loss for reconstructed images
            preds_1 = ocrModel(images_recon_1, text_1)
            preds_size_1 = torch.IntTensor([preds_1.size(1)] * batch_size)
            preds_1 = preds_1.log_softmax(2).permute(1, 0, 2)

            preds_2 = ocrModel(images_recon_2, text_2)
            preds_size_2 = torch.IntTensor([preds_2.size(1)] * batch_size)
            preds_2 = preds_2.log_softmax(2).permute(1, 0, 2)
            ocrCost_1 = ocrCriterion(preds_1, text_1, preds_size_1, length_1)
            ocrCost_2 = ocrCriterion(preds_2, text_2, preds_size_2, length_2)

        else:
            if not opt.ocrFixed:
                # ocr training with orig image
                preds_ocr = ocrModel(
                    image, text_1[:, :-1])  # align with Attention.forward
                target_ocr = text_1[:, 1:]  # without [GO] Symbol

                ocrCost_train = ocrCriterion(
                    preds_ocr.view(-1, preds_ocr.shape[-1]),
                    target_ocr.contiguous().view(-1))

            # content loss for reconstructed images
            preds_1 = ocrModel(images_recon_1, text_1[:, :-1],
                               is_train=False)  # align with Attention.forward
            target_1 = text_1[:, 1:]  # without [GO] Symbol

            preds_2 = ocrModel(images_recon_2, text_2[:, :-1],
                               is_train=False)  # align with Attention.forward
            target_2 = text_2[:, 1:]  # without [GO] Symbol

            ocrCost_1 = ocrCriterion(preds_1.view(-1, preds_1.shape[-1]),
                                     target_1.contiguous().view(-1))
            ocrCost_2 = ocrCriterion(preds_2.view(-1, preds_2.shape[-1]),
                                     target_2.contiguous().view(-1))

        if not opt.ocrFixed:
            # training OCR
            ocrModel.zero_grad()
            ocrCost_train.backward()
            ocr_optimizer.step()
            loss_avg_ocr.add(ocrCost_train)
        else:
            # if ocr is fixed; ignore this loss
            loss_avg_ocr.add(torch.tensor(0.0))

        # Domain discriminator: Dis update (generator outputs are detached so
        # this backward pass does not reach the generator)
        disModel.zero_grad()
        disCost = opt.disWeight * 0.5 * (
            disModel.module.calc_dis_loss(images_recon_1.detach(), image_real)
            + disModel.module.calc_dis_loss(images_recon_2.detach(), image))
        disCost.backward()
        dis_optimizer.step()
        loss_avg_dis.add(disCost)

        # [Style Encoder] + [Word Generator] update
        # Adversarial loss
        disGenCost = 0.5 * (disModel.module.calc_gen_loss(images_recon_1) +
                            disModel.module.calc_gen_loss(images_recon_2))

        # Input image reconstruction loss
        if opt.imgReconLoss == 'ssim':
            recCost = -1 * recCriterion(images_recon_1, image, val_range=2)
        elif opt.imgReconLoss == 'ms-ssim':
            recCost = -1 * recCriterion(
                images_recon_1, image, val_range=2, normalize='relu')
        else:
            # Average of the criterion over the masked and the unmasked region.
            recCost = 0.5 * (recCriterion(image_mask_tensors * images_recon_1,
                                          image_mask_tensors * image) +
                             recCriterion(
                                 (1 - image_mask_tensors) * images_recon_1,
                                 (1 - image_mask_tensors) * image))

        # Pair style reconstruction loss
        if opt.styleReconWeight == 0.0:
            styleRecCost = torch.tensor(0.0)
        else:
            # opt.styleDetach stops gradients through the reference style code.
            style_target = style.detach() if opt.styleDetach else style
            if opt.styleLoss == 'l1':
                styleRecCost = styleRecCriterion(
                    model(images_recon_2, None, None, styleFlag=True),
                    style_target)
            elif opt.styleLoss == 'triplet':
                styleRecCost = styleRecCriterion(
                    model(images_recon_2, None, None, styleFlag=True),
                    style_target,
                    model(image_real, None, None, styleFlag=True))

        # OCR Content cost
        ocrCost = 0.5 * (opt.ocrWeight_1 * ocrCost_1 +
                         opt.ocrWeight_2 * ocrCost_2)

        cost = (opt.ocrWeight * ocrCost + opt.reconWeight * recCost +
                opt.disWeight * disGenCost +
                opt.styleReconWeight * styleRecCost)

        # Gradients of all three networks are cleared, but only the generator
        # optimizer is stepped with this loss.
        model.zero_grad()
        ocrModel.zero_grad()
        disModel.zero_grad()
        cost.backward()
        optimizer.step()
        loss_avg.add(cost)

        # Individual losses
        loss_avg_ocrRecon_1.add(opt.ocrWeight * 0.5 * ocrCost_1)
        loss_avg_ocrRecon_2.add(opt.ocrWeight * 0.5 * ocrCost_2)
        loss_avg_gen.add(opt.disWeight * disGenCost)
        loss_avg_imgRecon.add(opt.reconWeight * recCost)
        loss_avg_styRecon.add(opt.styleReconWeight * styleRecCost)

        # ---------------- validation part ----------------
        # To see training progress, we also conduct validation when 'iteration == 0'
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:

            # Save training images
            train_image_dir = os.path.join(opt.exp_dir, opt.exp_name,
                                           'trainImages', str(iteration))
            os.makedirs(train_image_dir, exist_ok=True)
            for trImgCntr in range(batch_size):
                # Best effort: a failure to write one image must not stop
                # training.
                try:
                    save_image(
                        tensor2im(image[trImgCntr].detach()),
                        os.path.join(
                            train_image_dir,
                            str(trImgCntr) + '_input_' + labels_gt[trImgCntr] +
                            '.png'))
                    save_image(
                        tensor2im(images_recon_1[trImgCntr].detach()),
                        os.path.join(
                            train_image_dir,
                            str(trImgCntr) + '_recon_' + labels_1[trImgCntr] +
                            '.png'))
                    save_image(
                        tensor2im(images_recon_2[trImgCntr].detach()),
                        os.path.join(
                            train_image_dir,
                            str(trImgCntr) + '_pair_' + labels_2[trImgCntr] +
                            '.png'))
                except Exception:
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time

            with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'),
                      'a') as log:
                # Switch every network to eval mode for validation.
                model.eval()
                ocrModel.module.Transformation.eval()
                ocrModel.module.FeatureExtraction.eval()
                ocrModel.module.AdaptiveAvgPool.eval()
                ocrModel.module.SequenceModeling.eval()
                ocrModel.module.Prediction.eval()
                disModel.eval()

                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation_synth_lrw_res(
                        iteration, model, ocrModel, disModel, recCriterion,
                        styleTestRecCriterion, ocrCriterion, valid_loader,
                        converter, opt)

                # Restore the training modes used before validation.
                model.train()
                if not opt.ocrFixed:
                    ocrModel.train()
                else:
                    # Fixed OCR: only SequenceModeling goes back to train mode.
                    ocrModel.module.SequenceModeling.train()

                disModel.train()

                # training loss and validation loss
                loss_log = (
                    f'[{iteration+1}/{opt.num_iter}] Train OCR loss: {loss_avg_ocr.val():0.5f}, '
                    f'Train Synth loss: {loss_avg.val():0.5f}, '
                    f'Train Dis loss: {loss_avg_dis.val():0.5f}, '
                    f'Valid OCR loss: {valid_loss[0]:0.5f}, '
                    f'Valid Synth loss: {valid_loss[1]:0.5f}, '
                    f'Valid Dis loss: {valid_loss[2]:0.5f}, '
                    f'Elapsed_time: {elapsed_time:0.5f}')

                # Index 0 = OCR on the original image, 1 = reconstruction,
                # 2 = paired image (as labelled in the log lines below).
                current_model_log_ocr = f'{"Current_accuracy_OCR":17s}: {current_accuracy[0]:0.3f}, {"Current_norm_ED_OCR":17s}: {current_norm_ED[0]:0.2f}'
                current_model_log_1 = f'{"Current_accuracy_recon":17s}: {current_accuracy[1]:0.3f}, {"Current_norm_ED_recon":17s}: {current_norm_ED[1]:0.2f}'
                current_model_log_2 = f'{"Current_accuracy_pair":17s}: {current_accuracy[2]:0.3f}, {"Current_norm_ED_pair":17s}: {current_norm_ED[2]:0.2f}'

                # plotting
                plot_points = [
                    ('Train-OCR-Loss', loss_avg_ocr.val().item()),
                    ('Train-Synth-Loss', loss_avg.val().item()),
                    ('Train-Dis-Loss', loss_avg_dis.val().item()),
                    ('Train-OCR-Recon1-Loss', loss_avg_ocrRecon_1.val().item()),
                    ('Train-OCR-Recon2-Loss', loss_avg_ocrRecon_2.val().item()),
                    ('Train-Gen-Loss', loss_avg_gen.val().item()),
                    ('Train-ImgRecon1-Loss', loss_avg_imgRecon.val().item()),
                    ('Train-StyRecon2-Loss', loss_avg_styRecon.val().item()),
                    ('Valid-OCR-Loss', valid_loss[0].item()),
                    ('Valid-Synth-Loss', valid_loss[1].item()),
                    ('Valid-Dis-Loss', valid_loss[2].item()),
                    ('Valid-OCR-Recon1-Loss', valid_loss[3].item()),
                    ('Valid-OCR-Recon2-Loss', valid_loss[4].item()),
                    ('Valid-Gen-Loss', valid_loss[5].item()),
                    ('Valid-ImgRecon1-Loss', valid_loss[6].item()),
                    ('Valid-StyRecon2-Loss', valid_loss[7].item()),
                    ('Orig-OCR-WordAccuracy', current_accuracy[0]),
                    ('Recon-OCR-WordAccuracy', current_accuracy[1]),
                    ('Pair-OCR-WordAccuracy', current_accuracy[2]),
                    ('Orig-OCR-CharAccuracy', current_norm_ED[0]),
                    ('Recon-OCR-CharAccuracy', current_norm_ED[1]),
                    ('Pair-OCR-CharAccuracy', current_norm_ED[2]),
                ]
                for plot_name, plot_value in plot_points:
                    lib.plot.plot(os.path.join(plotDir, plot_name), plot_value)

                # keep best accuracy model (on valid dataset), judged on the
                # reconstruction metrics
                if current_accuracy[1] > best_accuracy:
                    best_accuracy = current_accuracy[1]
                    torch.save(
                        model.state_dict(),
                        os.path.join(opt.exp_dir, opt.exp_name,
                                     'best_accuracy.pth'))
                    torch.save(
                        disModel.state_dict(),
                        os.path.join(opt.exp_dir, opt.exp_name,
                                     'best_accuracy_dis.pth'))
                if current_norm_ED[1] > best_norm_ED:
                    best_norm_ED = current_norm_ED[1]
                    torch.save(
                        model.state_dict(),
                        os.path.join(opt.exp_dir, opt.exp_name,
                                     'best_norm_ED.pth'))
                    torch.save(
                        disModel.state_dict(),
                        os.path.join(opt.exp_dir, opt.exp_name,
                                     'best_norm_ED_dis.pth'))
                best_model_log = f'{"Best_accuracy_Recon":17s}: {best_accuracy:0.3f}, {"Best_norm_ED_Recon":17s}: {best_norm_ED:0.2f}'

                # keep best OCR model (on valid dataset); the weights are only
                # written when the OCR model is actually being trained
                if current_accuracy[0] > best_accuracy_ocr:
                    best_accuracy_ocr = current_accuracy[0]
                    if not opt.ocrFixed:
                        torch.save(
                            ocrModel.state_dict(),
                            os.path.join(opt.exp_dir, opt.exp_name,
                                         'best_accuracy_ocr.pth'))
                if current_norm_ED[0] > best_norm_ED_ocr:
                    best_norm_ED_ocr = current_norm_ED[0]
                    if not opt.ocrFixed:
                        torch.save(
                            ocrModel.state_dict(),
                            os.path.join(opt.exp_dir, opt.exp_name,
                                         'best_norm_ED_ocr.pth'))
                best_model_log_ocr = f'{"Best_accuracy_ocr":17s}: {best_accuracy_ocr:0.3f}, {"Best_norm_ED_ocr":17s}: {best_norm_ED_ocr:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log_ocr}\n{current_model_log_1}\n{current_model_log_2}\n{best_model_log_ocr}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results (first five samples of each set)
                dashed_line = '-' * 80
                head = f'{"Ground Truth":32s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                for gt_ocr, pred_ocr, confidence_ocr, gt_1, pred_1, confidence_1, gt_2, pred_2, confidence_2 in zip(
                        labels[0][:5], preds[0][:5], confidence_score[0][:5],
                        labels[1][:5], preds[1][:5], confidence_score[1][:5],
                        labels[2][:5], preds[2][:5], confidence_score[2][:5]):
                    if 'Attn' in opt.Prediction:
                        # cut the predictions at the end-of-sentence token
                        pred_ocr = pred_ocr[:pred_ocr.find('[s]')]
                        pred_1 = pred_1[:pred_1.find('[s]')]
                        pred_2 = pred_2[:pred_2.find('[s]')]

                    predicted_result_log += f'{"ocr"}: {gt_ocr:27s} | {pred_ocr:25s} | {confidence_ocr:0.4f}\t{str(pred_ocr == gt_ocr)}\n'
                    predicted_result_log += f'{"recon"}: {gt_1:25s} | {pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n'
                    predicted_result_log += f'{"pair"}: {gt_2:26s} | {pred_2:25s} | {confidence_2:0.4f}\t{str(pred_2 == gt_2)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

                # start a fresh averaging window for the next validation round
                loss_avg_ocr.reset()
                loss_avg.reset()
                loss_avg_dis.reset()

                loss_avg_ocrRecon_1.reset()
                loss_avg_ocrRecon_2.reset()
                loss_avg_gen.reset()
                loss_avg_imgRecon.reset()
                loss_avg_styRecon.reset()

            lib.plot.flush()

        lib.plot.tick()

        # save model per 1e+5 iter.
        if (iteration) % 1e+5 == 0:
            torch.save(
                model.state_dict(),
                os.path.join(opt.exp_dir, opt.exp_name,
                             'iter_' + str(iteration + 1) + '_synth.pth'))
            if not opt.ocrFixed:
                torch.save(
                    ocrModel.state_dict(),
                    os.path.join(opt.exp_dir, opt.exp_name,
                                 'iter_' + str(iteration + 1) + '_ocr.pth'))
            torch.save(
                disModel.state_dict(),
                os.path.join(opt.exp_dir, opt.exp_name,
                             'iter_' + str(iteration + 1) + '_dis.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
Exemple #7
0
def train(opt):
    """Train a single text-recognition model with periodic validation.

    Builds the balanced training set and the validation loader, picks the
    label converter matching ``opt.Prediction`` / ``opt.SequenceModeling``,
    initialises (and optionally restores) the model, then loops over training
    batches.  Every ``opt.valInterval`` iterations the model is validated, the
    best-accuracy and best-norm-ED weights are saved and a few predictions are
    printed.  A checkpoint is additionally written every 1e+4 iterations.
    All files go under ``./saved_models/{opt.experiment_name}/``.

    Args:
        opt: option namespace.  It is mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set from
            the converter and ``input_channel`` is forced to 3 when
            ``opt.rgb`` is set.

    Raises:
        ValueError: if the options select no known label converter.
        SystemExit: through ``sys.exit()`` once ``i == opt.num_iter``.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    save_dir = f'./saved_models/{opt.experiment_name}'

    # The context manager closes the dataset log even when building the
    # validation set raises.
    with open(f'{save_dir}/log_dataset.txt', 'a', encoding='utf-16') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset, batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid, pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    elif 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    elif 'Transformer' in opt.Prediction or 'Test' in opt.Prediction or 'Transformer' in opt.SequenceModeling:
        converter = TransformerLabelConverter(opt.character)
    else:
        # Fail with a readable message instead of an UnboundLocalError below.
        raise ValueError(
            f'No label converter for Prediction={opt.Prediction!r}, '
            f'SequenceModeling={opt.SequenceModeling!r}')
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization: zero biases, Kaiming-normal weights
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm (per the original comment)
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            # fine-tuning: tolerate missing / unexpected keys
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'{save_dir}/opt.txt', 'a', encoding='utf-16') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        # Checkpoints written below are named "iter_<N>.pth"; resume from N
        # when the file name follows that pattern, otherwise start at 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    i = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)

            # (ctc_a) For PyTorch 1.2.0 and 1.3.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device))
            torch.backends.cudnn.enabled = True

            # # (ctc_b) To reproduce our pretrained model / paper, use our previous code (below code) instead of (ctc_a).
            # # With PyTorch 1.2.0, the below code occurs NAN, so you may use PyTorch 1.1.0.
            # # Thus, the result of CTCLoss is different in PyTorch 1.1.0 and PyTorch 1.2.0.
            # # See https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707
            # cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'{save_dir}/log_train.txt', 'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'{save_dir}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'{save_dir}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results (printed only, not written to the log file)
                dashed_line = '-' * 80
                head = f'{"Ground Truth":10s} | {"Prediction":10s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
                    if 'Attn' in opt.Prediction or 'Transformer' in opt.Prediction or 'Test' in opt.Prediction or 'Transformer' in opt.SequenceModeling:
                        # cut both strings at the '[s]' end-of-sequence marker
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:10s} | {pred:10s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                #log.write(predicted_result_log + '\n')

        # save model per 1e+4 iter.
        if (i + 1) % 1e+4 == 0:
            torch.save(
                model.state_dict(), f'{save_dir}/iter_{i+1}.pth')

        if i == opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
def train(opt):
    """Jointly train the synthesis generator, the OCR model and the discriminator.

    Three networks are built from ``opt``: ``AdaINGen`` (generator, called
    ``model``), ``Model`` (recognizer, ``ocrModel``) and ``MsImageDis``
    (discriminator, ``disModel``).  Each iteration draws one batch, uses its
    first half as generator input and its second half as "real" images for the
    discriminator, and then applies three optimizer steps in this order:
    discriminator, generator, recognizer.  Validation runs at iteration 0 and
    every ``opt.valInterval`` iterations; best and periodic checkpoints, logs
    and sample training images are written under
    ``os.path.join(opt.exp_dir, opt.exp_name)``.

    Args:
        opt: option namespace.  It is mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``batch_size`` is doubled,
            ``num_class`` is set from the converter and ``input_channel`` is
            forced to 3 when ``opt.rgb`` is set.

    Raises:
        NotImplementedError: on the first training step when ``'CTC'`` is not
            in ``opt.Prediction``; only the CTC path defines every loss used
            by the update steps.
        SystemExit: through ``sys.exit()`` once
            ``iteration + 1 == opt.num_iter``.
    """

    def _init_weights(net):
        """Zero the biases and Kaiming-initialise the weights of *net*."""
        for name, param in net.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception:  # for batchnorm (per the original comment)
                if 'weight' in name:
                    param.data.fill_(1)

    def _restore_and_show(net, path, kind, title):
        """Load weights from *path* (if given) into *net*, then print it."""
        if path != '':
            print(f'loading pretrained {kind} model from {path}')
            if opt.FT:
                # fine-tuning: tolerate missing / unexpected keys
                net.load_state_dict(torch.load(path), strict=False)
            else:
                net.load_state_dict(torch.load(path))
        print(title)
        print(net)

    def _make_optimizer(net, count_label, title):
        """Build the optimizer over the trainable parameters of *net*."""
        trainable = [p for p in net.parameters() if p.requires_grad]
        print(count_label, sum(np.prod(p.size()) for p in trainable))
        if opt.adam:
            optimizer_ = optim.Adam(trainable,
                                    lr=opt.lr,
                                    betas=(opt.beta1, 0.999))
        else:
            optimizer_ = optim.Adadelta(trainable,
                                        lr=opt.lr,
                                        rho=opt.rho,
                                        eps=opt.eps)
        print(title)
        print(optimizer_)
        return optimizer_

    def _cut_at_eos(decoded):
        """Drop everything from the first '[s]' marker on; no-op if absent."""
        eos = decoded.find('[s]')
        return decoded if eos == -1 else decoded[:eos]

    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')

    # considering the real images for discriminator: the second half of every
    # batch is used as "real" samples in the training loop below
    opt.batch_size = opt.batch_size * 2

    train_dataset = Batch_Balanced_Dataset(opt)

    exp_path = os.path.join(opt.exp_dir, opt.exp_name)

    # The context manager closes the dataset log even when building the
    # validation set raises.
    with open(os.path.join(exp_path, 'log_dataset.txt'), 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    model = AdaINGen(opt)
    ocrModel = Model(opt)
    disModel = MsImageDis(opt)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization: generator, recognizer, discriminator
    _init_weights(model)
    _init_weights(ocrModel)
    _init_weights(disModel)

    # data parallel for multi-GPU
    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    ocrModel.train()

    model = torch.nn.DataParallel(model).to(device)
    model.train()

    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    # loading pre-trained models
    _restore_and_show(ocrModel, opt.saved_ocr_model, 'ocr', "OCRModel:")
    _restore_and_show(model, opt.saved_synth_model, 'synth', "SynthModel:")
    _restore_and_show(disModel, opt.saved_dis_model, 'discriminator',
                      "DisModel:")

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0

    recCriterion = torch.nn.L1Loss()
    styleRecCriterion = torch.nn.L1Loss()

    # loss averagers
    loss_avg_ocr = Averager()
    loss_avg = Averager()
    loss_avg_dis = Averager()

    # ---- optimizers: one per network, over its trainable parameters ----
    optimizer = _make_optimizer(model, 'Trainable params num : ',
                                "SynthOptimizer:")
    ocr_optimizer = _make_optimizer(ocrModel, 'OCR Trainable params num : ',
                                    "OCROptimizer:")
    dis_optimizer = _make_optimizer(disModel, 'Dis Trainable params num : ',
                                    "DisOptimizer:")

    # ---- final options ----
    with open(os.path.join(exp_path, 'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_synth_model != '':
        # Resume from N when the checkpoint name ends in "_<N>.<ext>";
        # otherwise start at 0.
        try:
            start_iter = int(
                opt.saved_synth_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    best_accuracy_ocr = -1
    best_norm_ED_ocr = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors_all, labels_1_all, labels_2_all = train_dataset.get_batch(
        )

        # first half: generator inputs (+ labels); second half: real images
        # for the discriminator
        disCnt = int(image_tensors_all.size(0) / 2)
        image_tensors = image_tensors_all[:disCnt]
        image_tensors_real = image_tensors_all[disCnt:disCnt + disCnt]
        labels_1 = labels_1_all[:disCnt]
        labels_2 = labels_2_all[:disCnt]

        image = image_tensors.to(device)
        image_real = image_tensors_real.to(device)
        text_1, length_1 = converter.encode(
            labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(
            labels_2, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        images_recon_1, images_recon_2, style = model(image, text_1, text_2)

        if 'CTC' in opt.Prediction:

            # ocr training: recognizer loss on the input images
            preds_ocr = ocrModel(image, text_1)
            preds_size_ocr = torch.IntTensor([preds_ocr.size(1)] * batch_size)
            preds_ocr = preds_ocr.log_softmax(2).permute(1, 0, 2)

            ocrCost_train = ocrCriterion(preds_ocr, text_1, preds_size_ocr,
                                         length_1)

            # dis training (generated images are detached so this loss only
            # reaches the discriminator)
            # Check: Using alternate real images
            disCost = opt.disWeight * 0.5 * (
                disModel.module.calc_dis_loss(images_recon_1.detach(),
                                              image_real) +
                disModel.module.calc_dis_loss(images_recon_2.detach(), image))

            # synth training: recognizer loss on both generated images
            preds_1 = ocrModel(images_recon_1, text_1)
            preds_size_1 = torch.IntTensor([preds_1.size(1)] * batch_size)
            preds_1 = preds_1.log_softmax(2).permute(1, 0, 2)

            preds_2 = ocrModel(images_recon_2, text_2)
            preds_size_2 = torch.IntTensor([preds_2.size(1)] * batch_size)
            preds_2 = preds_2.log_softmax(2).permute(1, 0, 2)

            ocrCost = 0.5 * (
                ocrCriterion(preds_1, text_1, preds_size_1, length_1) +
                ocrCriterion(preds_2, text_2, preds_size_2, length_2))

            # gen training
            disGenCost = 0.5 * (disModel.module.calc_gen_loss(images_recon_1) +
                                disModel.module.calc_gen_loss(images_recon_2))

        else:
            # The former non-CTC branch referenced an undefined `text` and
            # never produced disCost / disGenCost / ocrCost_train, so it could
            # only end in a NameError.  Fail explicitly instead.
            raise NotImplementedError(
                "train() only implements the CTC training path; "
                f"opt.Prediction={opt.Prediction!r} is not supported")

        recCost = recCriterion(images_recon_1, image)
        styleRecCost = styleRecCriterion(
            model(images_recon_2, None, None, styleFlag=True), style.detach())

        cost = opt.ocrWeight * ocrCost + opt.reconWeight * recCost + opt.disWeight * disGenCost + opt.styleReconWeight * styleRecCost

        # 1) discriminator update
        disModel.zero_grad()
        disCost.backward()
        torch.nn.utils.clip_grad_norm_(
            disModel.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        dis_optimizer.step()

        loss_avg_dis.add(disCost)

        # 2) generator update
        # NOTE(review): disGenCost was computed through disModel before the
        # step above changed its weights; some PyTorch versions may reject
        # this backward pass for that reason - TODO confirm on the version
        # in use.
        model.zero_grad()
        ocrModel.zero_grad()
        disModel.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # 3) recognizer update (gradients from step 2 are discarded first)
        ocrModel.zero_grad()
        ocrCost_train.backward()
        torch.nn.utils.clip_grad_norm_(
            ocrModel.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        ocr_optimizer.step()

        loss_avg_ocr.add(ocrCost_train)

        # validation part
        # To see training progress, we also conduct validation when 'iteration == 0'
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:

            # Save training images (best effort: a failing sample is reported
            # and skipped)
            img_dir = os.path.join(exp_path, 'trainImages', str(iteration))
            os.makedirs(img_dir, exist_ok=True)
            for trImgCntr in range(batch_size):
                try:
                    save_image(
                        tensor2im(image[trImgCntr].detach()),
                        os.path.join(
                            img_dir,
                            f'{trImgCntr}_input_{labels_1[trImgCntr]}.png'))
                    save_image(
                        tensor2im(images_recon_1[trImgCntr].detach()),
                        os.path.join(
                            img_dir,
                            f'{trImgCntr}_recon_{labels_1[trImgCntr]}.png'))
                    save_image(
                        tensor2im(images_recon_2[trImgCntr].detach()),
                        os.path.join(
                            img_dir,
                            f'{trImgCntr}_pair_{labels_2[trImgCntr]}.png'))
                except Exception:
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time

            # for log
            with open(os.path.join(exp_path, 'log_train.txt'), 'a') as log:
                model.eval()
                ocrModel.eval()
                disModel.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation_synth_lrw(
                        iteration, model, ocrModel, disModel, recCriterion,
                        styleRecCriterion, ocrCriterion, valid_loader,
                        converter, opt)
                model.train()
                ocrModel.train()
                disModel.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train OCR loss: {loss_avg_ocr.val():0.5f}, Train Synth loss: {loss_avg.val():0.5f}, Train Dis loss: {loss_avg_dis.val():0.5f}, Valid OCR loss: {valid_loss[0]:0.5f}, Valid Synth loss: {valid_loss[1]:0.5f}, Valid Dis loss: {valid_loss[2]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg_ocr.reset()
                loss_avg.reset()
                loss_avg_dis.reset()

                # index 0: OCR on inputs, 1: reconstruction, 2: pair
                current_model_log_ocr = f'{"Current_accuracy_OCR":17s}: {current_accuracy[0]:0.3f}, {"Current_norm_ED_OCR":17s}: {current_norm_ED[0]:0.2f}'
                current_model_log_1 = f'{"Current_accuracy_recon":17s}: {current_accuracy[1]:0.3f}, {"Current_norm_ED_recon":17s}: {current_norm_ED[1]:0.2f}'
                current_model_log_2 = f'{"Current_accuracy_pair":17s}: {current_accuracy[2]:0.3f}, {"Current_norm_ED_pair":17s}: {current_norm_ED[2]:0.2f}'

                # keep best generator / discriminator (judged on the
                # reconstruction metrics of the valid dataset)
                if current_accuracy[1] > best_accuracy:
                    best_accuracy = current_accuracy[1]
                    torch.save(model.state_dict(),
                               os.path.join(exp_path, 'best_accuracy.pth'))
                    torch.save(disModel.state_dict(),
                               os.path.join(exp_path, 'best_accuracy_dis.pth'))
                if current_norm_ED[1] > best_norm_ED:
                    best_norm_ED = current_norm_ED[1]
                    torch.save(model.state_dict(),
                               os.path.join(exp_path, 'best_norm_ED.pth'))
                    torch.save(disModel.state_dict(),
                               os.path.join(exp_path, 'best_norm_ED_dis.pth'))
                best_model_log = f'{"Best_accuracy_Recon":17s}: {best_accuracy:0.3f}, {"Best_norm_ED_Recon":17s}: {best_norm_ED:0.2f}'

                # keep best recognizer (judged on the OCR metrics)
                if current_accuracy[0] > best_accuracy_ocr:
                    best_accuracy_ocr = current_accuracy[0]
                    torch.save(ocrModel.state_dict(),
                               os.path.join(exp_path, 'best_accuracy_ocr.pth'))
                if current_norm_ED[0] > best_norm_ED_ocr:
                    best_norm_ED_ocr = current_norm_ED[0]
                    torch.save(ocrModel.state_dict(),
                               os.path.join(exp_path, 'best_norm_ED_ocr.pth'))
                best_model_log_ocr = f'{"Best_accuracy_ocr":17s}: {best_accuracy_ocr:0.3f}, {"Best_norm_ED_ocr":17s}: {best_norm_ED_ocr:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log_ocr}\n{current_model_log_1}\n{current_model_log_2}\n{best_model_log_ocr}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":32s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt_ocr, pred_ocr, confidence_ocr, gt_1, pred_1, confidence_1, gt_2, pred_2, confidence_2 in zip(
                        labels[0][:5], preds[0][:5], confidence_score[0][:5],
                        labels[1][:5], preds[1][:5], confidence_score[1][:5],
                        labels[2][:5], preds[2][:5], confidence_score[2][:5]):
                    if 'Attn' in opt.Prediction:
                        # cut at the '[s]' marker; strings without the marker
                        # are left untouched
                        gt_ocr, pred_ocr = _cut_at_eos(gt_ocr), _cut_at_eos(pred_ocr)
                        gt_1, pred_1 = _cut_at_eos(gt_1), _cut_at_eos(pred_1)
                        gt_2, pred_2 = _cut_at_eos(gt_2), _cut_at_eos(pred_2)

                    predicted_result_log += f'{"ocr"}: {gt_ocr:27s} | {pred_ocr:25s} | {confidence_ocr:0.4f}\t{str(pred_ocr == gt_ocr)}\n'
                    predicted_result_log += f'{"recon"}: {gt_1:25s} | {pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n'
                    predicted_result_log += f'{"pair"}: {gt_2:26s} | {pred_2:25s} | {confidence_2:0.4f}\t{str(pred_2 == gt_2)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        # The names must be f-strings: as plain literals every periodic save
        # landed on the same file called "iter_{iteration+1}.pth".
        if (iteration + 1) % 1e+5 == 0:
            torch.save(model.state_dict(),
                       os.path.join(exp_path, f'iter_{iteration+1}.pth'))
            torch.save(ocrModel.state_dict(),
                       os.path.join(exp_path, f'iter_{iteration+1}_ocr.pth'))
            torch.save(disModel.state_dict(),
                       os.path.join(exp_path, f'iter_{iteration+1}_dis.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
Exemple #9
0
def train(opt):
    """Train the model with a joint CTC (refiner) + attention loss.

    Builds the training and validation pipelines, initialises the model and
    optimizer, optionally restores ``opt.saved_model`` and then iterates up to
    ``opt.num_iter``. Validation runs on iteration 0 and every
    ``opt.valInterval`` iterations. Text logs and checkpoints are written
    under ``./saved_models/{opt.exp_name}/``; scalars go to the TensorBoard
    directory ``opt.log``.

    Args:
        opt: Option namespace. It is mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set from
            the attention converter, ``input_channel`` is forced to 3 when
            ``opt.rgb`` is set, and ``grad_clip`` is tightened as the running
            loss drops.

    Returns:
        None, only when the start iteration parsed from ``opt.saved_model``
        is already >= ``opt.num_iter`` (nothing left to train).

    Raises:
        SystemExit: via ``sys.exit()`` after the final iteration completes.
    """
    os.makedirs(opt.log, exist_ok=True)
    writer = SummaryWriter(opt.log)

    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager closes the dataset log even if building the
    # validation pipeline raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=
            True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)

        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    ctc_converter = CTCLabelConverter(opt.character)
    attn_converter = AttnLabelConverter(opt.character)
    opt.num_class = len(attn_converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))

    # ---- setup loss ----
    loss_avg = Averager()
    ctc_loss = torch.nn.CTCLoss(zero_infinity=True).to(device)
    attn_loss = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    # keep only the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    # print(opt)
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # The checkpoint name carries no trailing iteration number
            # (e.g. 'SCATTER_STR.pth'); start counting from 0.
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    # Start at start_iter so a resumed run really continues where it stopped;
    # previously the loop variable silently overwrote it with 0.
    pbar = tqdm(range(start_iter, opt.num_iter))

    for iteration in pbar:

        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        ctc_text, ctc_length = ctc_converter.encode(
            labels, batch_max_length=opt.batch_max_length)
        attn_text, attn_length = attn_converter.encode(
            labels, batch_max_length=opt.batch_max_length)

        batch_size = image.size(0)

        preds, refiner = model(image, attn_text[:, :-1])

        refiner_size = torch.IntTensor([refiner.size(1)] * batch_size)
        refiner = refiner.log_softmax(2).permute(1, 0, 2)
        refiner_loss = ctc_loss(refiner, ctc_text, refiner_size, ctc_length)

        total_loss = opt.lambda_ctc * refiner_loss
        target = attn_text[:, 1:]  # without [GO] Symbol
        for pred in preds:
            total_loss += opt.lambda_attn * attn_loss(
                pred.view(-1, pred.shape[-1]),
                target.contiguous().view(-1))

        model.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        loss_avg.add(total_loss)
        # Tighten the clipping threshold as the running loss gets small.
        if loss_avg.val() <= 0.6:
            opt.grad_clip = 2
        if loss_avg.val() <= 0.3:
            opt.grad_clip = 1

        # NOTE(review): this generator is never consumed, so no prediction is
        # actually moved to the CPU here; kept as-is to avoid adding
        # per-iteration device-to-host copies.
        preds = (p.cpu() for p in preds)
        refiner = refiner.cpu()
        image = image.cpu()
        torch.cuda.empty_cache()

        writer.add_scalar('train_loss', loss_avg.val(), iteration)
        pbar.set_description('Iteration {0}/{1}, AvgLoss {2}'.format(
            iteration, opt.num_iter, loss_avg.val()))

        # validation part
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, attn_loss, valid_loader, attn_converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                # Pass the step explicitly so validation points are not all
                # recorded without a global step.
                writer.add_scalar('Val_loss', valid_loss, iteration)
                pbar.set_description(loss_log)
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy_{str(best_accuracy)}.pth'
                    )
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                # print(loss_model_log)

                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    # Validation always decodes with attn_converter, so the
                    # text is always cut at the '[s]' token. The former guard
                    # ("'Attn' or 'Transformer' in opt.Prediction") was
                    # always true and never consulted opt.Prediction.
                    gt = gt[:gt.find('[s]')]
                    pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                log.write(predicted_result_log + '\n')

        # save model per 1e+3 iter.
        if (iteration + 1) % 1e+3 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.exp_name}/SCATTER_STR.pth')

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            # Flush pending TensorBoard events before leaving the process.
            writer.close()
            sys.exit()
Exemple #10
0
def train(opt):
    """Train a text-recognition model with a CTC or attention head.

    Builds LMDB-backed training and validation loaders, initialises the model
    and optimizer, optionally restores ``opt.continue_model`` and trains until
    the iteration counter reaches ``opt.num_iter``. Every ``opt.valInterval``
    iterations the model is validated; logs and checkpoints are written under
    ``./saved_models/{opt.experiment_name}/``.

    Args:
        opt: Option namespace. ``num_class`` is set from the label converter
            and ``input_channel`` is forced to 3 when ``opt.rgb`` is set.

    Raises:
        SystemExit: via ``sys.exit()`` once ``opt.num_iter`` is reached.
    """
    # ---- prepare the training and validation datasets ----
    transform = transforms.Compose([
        ToTensor(),
    ])
    train_dataset = LmdbDataset(opt.train_data, opt=opt, transform=transform)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
    )

    valid_dataset = LmdbDataset(root=opt.valid_data,
                                opt=opt,
                                transform=transform)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=int(opt.workers),
    )
    print('-' * 80)

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    model = model.to(device)
    model.train()
    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        model.load_state_dict(torch.load(opt.continue_model))
    print("Model:")
    print(model)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradient descent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    # print(opt)
    with open(f'./saved_models/{opt.experiment_name}/opt.txt',
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.continue_model != '':
        try:
            start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # Checkpoints such as best_accuracy.pth / best_norm_ED.pth (both
            # written below) carry no iteration number; restart the counter
            # instead of crashing on int().
            print(
                f'no iteration number in {opt.continue_model}, start_iter: 0'
            )

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    while True:
        # train part
        for image_tensors, labels in train_loader:
            image = image_tensors.to(device)
            text, length = converter.encode(
                labels, batch_max_length=opt.batch_max_length
            )  # text: [index, index, ..., index], length: [10, 8]
            batch_size = image.size(0)

            if 'CTC' in opt.Prediction:
                # set xx = model(image, text) torch.Size([100, 63, 7]), xx.log_softmax(2)[0][0] = xx[0][0].log_softmax(-1)
                preds = model(image,
                              text).log_softmax(2)  # torch.Size([100, 63, 12])
                preds_size = torch.IntTensor([preds.size(1)] *
                                             batch_size).to(device)
                preds = preds.permute(
                    1, 0, 2
                )  # to use CTCLoss format  # 100 * 63 * 7 ->  63 * 100 * 7

                # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
                # https://github.com/jpuigcerver/PyLaia/issues/16
                torch.backends.cudnn.enabled = False
                # Example shapes from the original author's notes:
                # preds.shape: torch.Size([63, 100, 7]) where 63 is the
                # sequence length, 100 the batch size and 7 the number of
                # output classes; text.shape: torch.Size([1000]), i.e. 1000
                # characters in total.
                # preds_size: [63, 63, ..., 63] (100 entries), each entry is
                # the sequence length; length: [10, 10, ..., 10] (100
                # entries), each entry is the length of one label, i.e. every
                # image holds 10 characters.
                cost = criterion(preds, text, preds_size, length)
                torch.backends.cudnn.enabled = True

            else:
                preds = model(image,
                              text[:, :-1])  # align with Attention.forward
                target = text[:, 1:]  # without [GO] Symbol
                cost = criterion(preds.view(-1, preds.shape[-1]),
                                 target.contiguous().view(-1))

            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer.step()

            loss_avg.add(cost)

            # validation part
            if i % opt.valInterval == 0:
                elapsed_time = time.time() - start_time
                print(
                    f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
                )
                # for log
                with open(
                        f'./saved_models/{opt.experiment_name}/log_train.txt',
                        'a') as log:
                    log.write(
                        f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                    )
                    loss_avg.reset()

                    model.eval()
                    with torch.no_grad():
                        valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                            model, criterion, valid_loader, converter, opt)
                    model.train()

                    for pred, gt in zip(preds[:5], labels[:5]):
                        if 'Attn' in opt.Prediction:
                            pred = pred[:pred.find('[s]')]
                            gt = gt[:gt.find('[s]')]
                        print(f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}')
                        log.write(
                            f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}\n')

                    valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                    valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                    print(valid_log)
                    log.write(valid_log + '\n')

                    # keep best accuracy model
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                        )
                    # here a lower norm_ED counts as better
                    if current_norm_ED < best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                        )
                    best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                    print(best_model_log)
                    log.write(best_model_log + '\n')

            # save model per 1e+5 iter.
            if (i + 1) % 1e+5 == 0:
                torch.save(
                    model.state_dict(),
                    f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

            # '>=' rather than '==' so that resuming from a checkpoint whose
            # iteration already exceeds num_iter still terminates.
            if i >= opt.num_iter:
                print('end the training')
                sys.exit()
            i += 1
Exemple #11
0
def train(opt,
          AMP,
          WdB,
          ralph_path,
          train_data_path,
          train_data_list,
          test_data_path,
          test_data_list,
          experiment_name,
          train_batch_size,
          val_batch_size,
          workers,
          lr,
          valInterval,
          num_iter,
          wdbprj,
          continue_model='',
          finetune=''):
    """Train an OrigamiNet model with CTC loss and an EMA copy for validation.

    Training runs in rounds of ``valInterval`` steps; after each round the EMA
    model is validated, a summary line is printed and appended to
    ``./saved_models/{experiment_name}/log_train.txt`` and, when the
    normalised edit distance improves, a checkpoint is saved. Relies on the
    module-level globals ``pO``, ``OnceExecWorker``, ``device`` and ``DEBUG``.

    Args:
        opt: Option namespace; ``num_gpu``, ``world_size`` and ``rank`` are
            read and all of its attributes are dumped to ``opt.txt``.
        AMP: Mixed-precision flag. NOTE(review): the apex code path is
            commented out, so when this is truthy no backward pass is run.
        WdB: Unused here (the wandb logging code is commented out).
        ralph_path: Path of a JSON file loaded and passed as ``ralph`` to
            both dataset constructors.
        train_data_path: Passed to ``ds_load.myLoadDS`` for the training set.
        train_data_list: Passed to ``ds_load.myLoadDS`` for the training set.
        test_data_path: Passed to ``ds_load.myLoadDS`` for the validation set.
        test_data_list: Passed to ``ds_load.myLoadDS`` for the validation set.
        experiment_name: Sub-directory of ``./saved_models`` used for output.
        train_batch_size: Batch size of the training loader.
        val_batch_size: Batch size of the validation loader.
        workers: DataLoader worker count (multiplied by ``opt.num_gpu`` when
            several GPUs are used without Horovod/DDP).
        lr: Initial learning rate for Adam.
        valInterval: Number of training steps between two validations.
        num_iter: Training stops once the step counter is >= this value
            (checked after each validation); a non-positive value never stops.
        wdbprj: Unused here (the wandb logging code is commented out).
        continue_model: Optional checkpoint path holding ``'model'`` and
            ``'optimizer'`` entries to resume from.
        finetune: Optional path handed to ``load_finetune``.

    Returns:
        None.
    """
    HVD3P = pO.HVD or pO.DDP

    os.makedirs(f'./saved_models/{experiment_name}', exist_ok=True)

    # if OnceExecWorker and WdB:
    #     wandb.init(project=wdbprj, name=experiment_name)
    #     wandb.config.update(opt)

    # load supplied ralph
    with open(ralph_path, 'r') as f:
        ralph_train = json.load(f)

    print('[4] IN TRAIN; BEFORE MAKING DATASET')
    train_dataset = ds_load.myLoadDS(train_data_list,
                                     train_data_path,
                                     ralph=ralph_train)
    valid_dataset = ds_load.myLoadDS(test_data_list,
                                     test_data_path,
                                     ralph=ralph_train)

    # SAVE RALPH FOR LATER USE
    # with open(f'./saved_models/{experiment_name}/ralph.json', 'w+') as f:
    #     json.dump(train_dataset.ralph, f)

    print('[5] DATASET DONE LOADING')
    if OnceExecWorker:
        print(pO)
        print('Alphabet :', len(train_dataset.alph), train_dataset.alph)
        for d in [train_dataset, valid_dataset]:
            print('Dataset Size :', len(d.fns))
            # print('Max LbW : ',max(list(map(len,d.tlbls))) )
            # print('#Chars : ',sum([len(x) for x in d.tlbls]))
            # print('Sample label :',d.tlbls[-1])
            # print("Dataset :", sorted(list(map(len,d.tlbls))) )
            print('-' * 80)

    if opt.num_gpu > 1:
        workers = workers * (1 if HVD3P else opt.num_gpu)

    if HVD3P:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, num_replicas=opt.world_size, rank=opt.rank)
        valid_sampler = torch.utils.data.distributed.DistributedSampler(
            valid_dataset, num_replicas=opt.world_size, rank=opt.rank)

    # A distributed sampler does its own shuffling, so the loader only
    # shuffles when no sampler is supplied.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=not HVD3P,
        pin_memory=True,
        num_workers=int(workers),
        sampler=train_sampler if HVD3P else None,
        worker_init_fn=WrkSeeder,
        collate_fn=ds_load.SameTrCollate)

    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=val_batch_size,
        pin_memory=True,
        num_workers=int(workers),
        sampler=valid_sampler if HVD3P else None)

    model = OrigamiNet()
    model.apply(init_bn)

    # load finetune ckpt
    if finetune != '':
        model = load_finetune(model, finetune)

    model.train()

    if OnceExecWorker:
        for k in sorted(model.lreszs.keys()):
            print(k, model.lreszs[k])

    if not pO.DDP:
        model = model.to(device)
    else:
        model.cuda(opt.rank)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                          gamma=10**(-1 /
                                                                     90000))

    # if OnceExecWorker and WdB:
    #     wandb.watch(model, log="all")

    if pO.HVD:
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        hvd.broadcast_optimizer_state(optimizer, root_rank=0)
        optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters())
        # optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters(), compression=hvd.Compression.fp16)

    # Re-seed the RNGs on every DDP rank except rank 0.
    if pO.DDP and opt.rank != 0:
        random.seed()
        np.random.seed()

    # if AMP:
    #     model, optimizer = amp.initialize(model, optimizer, opt_level = "O1")
    if pO.DP:
        model = torch.nn.DataParallel(model)
    elif pO.DDP:
        model = pDDP(model,
                     device_ids=[opt.rank],
                     output_device=opt.rank,
                     find_unused_parameters=False)

    model_ema = ModelEma(model)

    if continue_model != '':
        if OnceExecWorker:
            print(f'loading pretrained model from {continue_model}')
        checkpoint = torch.load(
            continue_model, map_location=f'cuda:{opt.rank}' if HVD3P else None)
        model.load_state_dict(checkpoint['model'], strict=True)
        optimizer.load_state_dict(checkpoint['optimizer'])
        model_ema._load_checkpoint(continue_model,
                                   f'cuda:{opt.rank}' if HVD3P else None)

    criterion = torch.nn.CTCLoss(reduction='none',
                                 zero_infinity=True).to(device)
    converter = CTCLabelConverter(train_dataset.ralph.values())

    if OnceExecWorker:
        with open(f'./saved_models/{experiment_name}/opt.txt',
                  'a') as opt_file:
            opt_log = '------------ Options -------------\n'
            args = vars(opt)
            for k, v in args.items():
                opt_log += f'{str(k)}: {str(v)}\n'
            opt_log += '---------------------------------------\n'
            opt_log += gin.operative_config_str()
            opt_file.write(opt_log)
            # if WdB:
            #     wandb.config.gin_str = gin.operative_config_str().splitlines()

        print(optimizer)
        print(opt_log)

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    best_CER = 1e+6
    i = 0
    gAcc = 1  # gradient-accumulation factor
    epoch = 1
    btReplay = False and AMP  # batch replay is disabled
    max_batch_replays = 3

    if HVD3P:
        train_sampler.set_epoch(epoch)
    titer = iter(train_loader)

    while True:
        start_time = time.time()

        model.zero_grad()
        train_loss = Metric(pO, 'train_loss')

        for _ in trange(valInterval, leave=False, desc='Training'):

            # Load a batch, restarting the loader when an epoch is exhausted
            try:
                image_tensors, labels, fnames = next(titer)
            except StopIteration:
                epoch += 1
                if HVD3P:
                    train_sampler.set_epoch(epoch)
                titer = iter(train_loader)
                image_tensors, labels, fnames = next(titer)

            # log filenames
            # fnames = [f'{i}___{fname}' for fname in fnames]
            # with open(f'./saved_models/{experiment_name}/filelog.txt', 'a+') as f:
            #     f.write('\n'.join(fnames) + '\n')

            # Move to device
            image = image_tensors.to(device)
            text, length = converter.encode(labels)
            batch_size = image.size(0)

            replay_batch = True
            maxR = max_batch_replays
            while replay_batch and maxR > 0:
                maxR -= 1

                # Forward pass
                preds = model(image, text).float()
                preds_size = torch.IntTensor([preds.size(1)] *
                                             batch_size).to(device)
                preds = preds.permute(1, 0, 2).log_softmax(2)

                if i == 0 and OnceExecWorker:
                    print('Model inp : ', image.dtype, image.size())
                    print('CTC inp : ', preds.dtype, preds.size(),
                          preds_size[0])

                # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
                torch.backends.cudnn.enabled = False
                cost = criterion(preds, text.to(device), preds_size,
                                 length.to(device)).mean() / gAcc
                torch.backends.cudnn.enabled = True

                train_loss.update(cost)

                # cost tracking?
                # with open(f'./saved_models/{experiment_name}/steplog.txt', 'a+') as f:
                #     f.write(f'Step {i} cost: {cost}\n')

                optimizer.zero_grad()

                # Backward and step
                if not AMP:
                    cost.backward()
                    replay_batch = False
                else:
                    # NOTE(review): the apex AMP path is disabled, so with
                    # AMP enabled no backward pass runs and the batch is only
                    # forwarded max_batch_replays times; confirm before
                    # enabling AMP. Original (disabled) code:
                    # with amp.scale_loss(cost, optimizer) as scaled_loss:
                    #     scaled_loss.backward()
                    #     if pO.HVD: optimizer.synchronize()

                    # if optimizer.step is default_optimizer_step or not btReplay:
                    #     replay_batch = False
                    # elif maxR>0:
                    #     optimizer.step()
                    pass

            if btReplay:
                pass  #amp._amp_state.loss_scalers[0]._loss_scale = mx_sc

            if (i + 1) % gAcc == 0:

                if pO.HVD and AMP:
                    with optimizer.skip_synchronize():
                        optimizer.step()
                else:
                    optimizer.step()

                model.zero_grad()
                model_ema.update(model, num_updates=i / 2)

                # the learning rate decays once every two optimizer steps
                if (i + 1) % (gAcc * 2) == 0:
                    lr_scheduler.step()

            i += 1

        # validation part
        elapsed_time = time.time() - start_time
        start_time = time.time()

        model.eval()
        with torch.no_grad():

            valid_loss, current_accuracy, current_norm_ED, ted, bleu, preds, labels, infer_time = validation(
                model_ema.ema, criterion, valid_loader, converter, opt, pO)

        model.train()
        v_time = time.time() - start_time

        if OnceExecWorker:
            if current_norm_ED < best_norm_ED:
                best_norm_ED = current_norm_ED
                checkpoint = {
                    'model': model.state_dict(),
                    'state_dict_ema': model_ema.ema.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }
                torch.save(
                    checkpoint,
                    f'./saved_models/{experiment_name}/best_norm_ED.pth')

            if ted < best_CER:
                best_CER = ted

            if current_accuracy > best_accuracy:
                best_accuracy = current_accuracy

            # get_last_lr() reports the rate currently applied; get_lr() is
            # deprecated outside step() and, for ExponentialLR, returns a
            # value one decay step ahead.
            current_lr = lr_scheduler.get_last_lr()[0]

            out = f'[{i}] Loss: {train_loss.avg:0.5f} time: ({elapsed_time:0.1f},{v_time:0.1f})'
            out += f' vloss: {valid_loss:0.3f}'
            out += f' CER: {ted:0.4f} NER: {current_norm_ED:0.4f} lr: {current_lr:0.5f}'
            out += f' bAcc: {best_accuracy:0.1f}, bNER: {best_norm_ED:0.4f}, bCER: {best_CER:0.4f}, B: {bleu*100:0.2f}'
            print(out)

            with open(f'./saved_models/{experiment_name}/log_train.txt',
                      'a') as log:
                log.write(out + '\n')

            # if WdB:
            #     wandb.log({'lr': lr_scheduler.get_last_lr()[0], 'It':i, 'nED': current_norm_ED,  'B':bleu*100,
            #     'tloss':train_loss.avg, 'AnED': best_norm_ED, 'CER':ted, 'bestCER':best_CER, 'vloss':valid_loss})

        if DEBUG:
            print(
                f'[!!!] Iteration check. Value of i: {i} | Value of num_iter: {num_iter}'
            )

        # Stop once at least num_iter steps have run; i advances by
        # valInterval per round, so '>=' is required. A non-positive
        # num_iter disables the stop condition.
        if num_iter > 0 and i >= num_iter:
            print('end the training')
            #sys.exit()
            break
Exemple #12
0
def train(opt):
    """Train a text recognition model, validating and checkpointing as it goes.

    The function builds the training/validation data pipelines, a label
    converter, the model, the loss, an optimizer and a ``MultiStepLR``
    scheduler from ``opt``; it then loops over training batches until the
    iteration counter reaches ``opt.num_iter`` and leaves through
    ``sys.exit()``.

    Args:
        opt: Options namespace. Attributes read here: ``select_data``,
            ``train_csv``, ``val_csv``, ``batch_ratio``, ``valid_data``,
            ``batch_size``, ``workers``, ``imgH``, ``imgW``, ``PAD``,
            ``Prediction``, ``character``, ``max_seq``, ``SRN_PAD``, ``rgb``,
            ``continue_model``, ``adam``, ``ranger``, ``lr``, ``beta1``,
            ``rho``, ``eps``, ``grad_clip``, ``experiment_name``,
            ``disInterval``, ``valInterval``, ``saveInterval`` and
            ``num_iter``. Attributes written here: ``select_data`` and
            ``batch_ratio`` (split into lists in the non-baidu mode),
            ``num_class`` and, when ``opt.rgb`` is set, ``input_channel``.

    Side effects:
        Writes ``opt.txt``, ``log_train.txt`` and ``*.pth`` checkpoints under
        ``./saved_models/{opt.experiment_name}/`` (the directory is assumed
        to exist already) and terminates the process with ``sys.exit()``.
    """
    # ---- dataset preparation ----
    # Remember the mode up front: the non-baidu branch replaces
    # opt.select_data with a list, so later string comparisons would be
    # less obvious to read.
    use_baidu = opt.select_data == 'baidu'
    if use_baidu:
        train_set = BAIDUset(opt, opt.train_csv)
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=int(opt.workers),
            collate_fn=BaiduCollate(opt.imgH, opt.imgW, keep_ratio=False))
        val_set = BAIDUset(opt, opt.val_csv)
        valid_loader = torch.utils.data.DataLoader(
            val_set,
            batch_size=opt.batch_size,
            shuffle=True,
            num_workers=int(opt.workers),
            collate_fn=BaiduCollate(opt.imgH, opt.imgW, keep_ratio=False),
            pin_memory=True)

    else:
        opt.select_data = opt.select_data.split('-')
        opt.batch_ratio = opt.batch_ratio.split('-')
        train_dataset = Batch_Balanced_Dataset(opt)

        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=
            True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
    print('-' * 80)

    # ---- model configuration ----
    # The converter (label <-> index encoding) is chosen from the prediction
    # head name; anything that is not CTC/Bert/SRN falls back to attention.
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    elif 'Bert' in opt.Prediction:
        converter = TransformerConverter(opt.character, opt.max_seq)
    elif 'SRN' in opt.Prediction:
        converter = SRNConverter(opt.character, opt.SRN_PAD)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # kaiming_normal_ rejects some weights (the original comment
            # attributes this to batchnorm); those are filled with ones.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).cuda()
    model.train()
    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        model.load_state_dict(torch.load(opt.continue_model))
    print("Model:")
    print(model)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
    elif 'Bert' in opt.Prediction:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()
    elif 'SRN' in opt.Prediction:
        criterion = cal_performance
    else:
        criterion = torch.nn.CrossEntropyLoss(
            ignore_index=0).cuda()  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    elif opt.ranger:
        optimizer = Ranger(filtered_parameters, lr=opt.lr)
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # Halve the learning rate after 5, 20 and 30 scheduler steps.
    lrScheduler = lr_scheduler.MultiStepLR(optimizer, [5, 20, 30],
                                           gamma=0.5)

    # ---- final options ----
    with open(f'./saved_models/{opt.experiment_name}/opt.txt',
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.continue_model != '':
        # Checkpoints named like 'iter_<N>.pth' carry the iteration to resume
        # from; other names (e.g. 'best_accuracy.pth') restart the counter at
        # 0 instead of crashing on int().
        try:
            start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6  # this function treats a lower norm_ED as better
    i = start_iter
    if use_baidu:
        train_iter = iter(train_loader)
        step_per_epoch = len(train_set) / opt.batch_size
        print('一代有多少step:', step_per_epoch)
    else:
        step_per_epoch = train_dataset.nums_samples / opt.batch_size
        print('一代有多少step:', step_per_epoch)
    # step_per_epoch is a float quotient; taking `i % step_per_epoch` directly
    # only hits zero at erratic iterations when the dataset size is not a
    # multiple of the batch size. Use a whole number of steps (at least 1)
    # to decide when an epoch has elapsed.
    epoch_steps = max(1, int(np.ceil(step_per_epoch)))

    while True:
        # train part
        for p in model.parameters():
            p.requires_grad = True

        if use_baidu:
            try:
                image_tensors, labels = next(train_iter)
            except StopIteration:
                # Loader exhausted: start a new pass over the training set.
                train_iter = iter(train_loader)
                image_tensors, labels = next(train_iter)
        else:
            image_tensors, labels = train_dataset.get_batch()

        image = image_tensors.cuda()
        text, length = converter.encode(labels)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format
            cost = criterion(preds, text, preds_size, length)

        elif 'Bert' in opt.Prediction:
            pad_mask = None
            preds = model(image, pad_mask)
            # The model returns two prediction tensors; both are scored
            # against the same target and the losses are added.
            cost = criterion(preds[0].view(-1, preds[0].shape[-1]), text.contiguous().view(-1)) + \
                   criterion(preds[1].view(-1, preds[1].shape[-1]), text.contiguous().view(-1))

        elif 'SRN' in opt.Prediction:
            preds = model(image, None)
            cost, n_correct = criterion(preds, text)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # periodic progress line; also restarts the elapsed-time clock
        if i % opt.disInterval == 0:
            elapsed_time = time.time() - start_time
            print(
                f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
            )
            start_time = time.time()

        # validation part
        if i % opt.valInterval == 0 and i > start_iter:
            elapsed_time = time.time() - start_time
            print(
                f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
            )
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt',
                      'a') as log:
                log.write(
                    f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n'
                )
                loss_avg.reset()

                model.eval()
                # No gradients are needed while validating.
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # show a few predictions next to their ground truth
                for pred, gt in zip(preds[:5], labels[:5]):
                    if 'Attn' in opt.Prediction:
                        pred = pred[:pred.find('[s]')]
                        gt = gt[:gt.find('[s]')]
                    print(
                        f'pred: {pred:20s}, gt: {gt:20s},   {str(pred == gt)}')
                    log.write(
                        f'pred: {pred:20s}, gt: {gt:20s},   {str(pred == gt)}\n'
                    )

                valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                print(valid_log)
                log.write(valid_log + '\n')

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                    )
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                    )
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

        # save a checkpoint every opt.saveInterval iterations
        if (i + 1) % opt.saveInterval == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        # `>=` so that resuming from an iteration beyond num_iter still stops.
        if i >= opt.num_iter:
            print('end the training')
            sys.exit()

        if i > 0 and i % epoch_steps == 0:  # adjust the learning rate once per epoch
            lrScheduler.step()

        i += 1
Exemple #13
0
def train(opt):
    """Train a model with joint CTC and attention heads.

    Builds the balanced training dataset and the validation loader, creates
    one label converter and one loss per head, optionally enables apex
    mixed precision, and then loops over training batches. The optimized
    loss is ``0.1 * CTC loss`` plus the sum of the attention cross-entropy
    losses over every tensor in the model's attention output. Validation
    uses the attention criterion and converter only. The loop ends through
    ``sys.exit()`` once ``opt.num_iter`` iterations are done, or earlier if
    the loss becomes NaN or infinite.

    Args:
        opt: Options namespace. Attributes read here: ``data_filtering_off``,
            ``select_data``, ``batch_ratio``, ``exp_name``, ``imgH``,
            ``imgW``, ``PAD``, ``valid_data``, ``batch_size``, ``workers``,
            ``character``, ``rgb``, ``adam``, ``lr``, ``beta1``, ``rho``,
            ``eps``, ``fp16``, ``saved_model``, ``FT``, ``batch_max_length``,
            ``grad_clip``, ``valInterval`` and ``num_iter``. Attributes
            written here: ``select_data``, ``batch_ratio``,
            ``num_class_ctc``, ``num_class_atten`` and, when ``opt.rgb`` is
            set, ``input_channel``.

    Side effects:
        Writes dataset/option/training logs, TensorBoard events and
        ``*.pth`` checkpoints under ``./saved_models/{opt.exp_name}/`` and
        terminates the process with ``sys.exit()``.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager guarantees the dataset log is closed even if
    # building the validation dataset raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=
            True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    # One converter per head: CTC and attention.
    converter_ctc = CTCLabelConverter(opt.character)
    converter_atten = AttnLabelConverter(opt.character)
    opt.num_class_ctc = len(converter_ctc.character)
    opt.num_class_atten = len(converter_atten.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class_ctc, opt.num_class_atten, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling,
          opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # kaiming_normal_ rejects some weights (the original comment
            # attributes this to batchnorm); those are filled with ones.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p_: p_.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # use fp16 to train
    model = model.to(device)
    if opt.fp16:
        with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
            log.write('==> Enable fp16 training' + '\n')
        print('==> Enable fp16 training')
        # amp wraps model and optimizer before any DataParallel wrapping.
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # data parallel for multi-GPU
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        if opt.FT:
            # fine-tuning: tolerate missing/unexpected keys
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    print(model)

    # ---- setup loss ----
    criterion_ctc = torch.nn.CTCLoss(zero_infinity=True).to(device)
    criterion_atten = torch.nn.CrossEntropyLoss(ignore_index=0).to(
        device)  # ignore [GO] token = ignore index 0

    # loss averager
    loss_avg = Averager()

    # ---- final options ----
    writer = SummaryWriter(f'./saved_models/{opt.exp_name}')
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        # Checkpoints named like 'iter_<N>.pth' carry the iteration to resume
        # from; any other name leaves the counter at 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1  # this function treats a higher norm_ED as better
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        batch_size = image.size(0)
        text_ctc, length_ctc = converter_ctc.encode(
            labels, batch_max_length=opt.batch_max_length)
        text_atten, length_atten = converter_atten.encode(
            labels, batch_max_length=opt.batch_max_length)

        # The model returns a pair (CTC output, iterable of attention outputs).
        # text_atten[:, :-1]: align with Attention.forward
        preds_ctc, preds_atten = model(image, text_atten[:, :-1])

        # CTC loss, down-weighted by 0.1
        preds_size = torch.IntTensor([preds_ctc.size(1)] * batch_size)
        preds_ctc = preds_ctc.log_softmax(2).permute(1, 0, 2)
        cost_ctc = 0.1 * criterion_ctc(preds_ctc, text_ctc, preds_size,
                                       length_ctc)

        # Attention loss: every attention output is scored against the same
        # target and the losses are summed.
        target = text_atten[:, 1:]  # without [GO] Symbol
        flat_target = target.contiguous().view(-1)
        cost_atten = sum(
            1.0 * criterion_atten(pred.view(-1, pred.shape[-1]), flat_target)
            for pred in preds_atten)
        cost = cost_ctc + cost_atten
        writer.add_scalar('loss', cost.item(), global_step=iteration + 1)

        # Single-line progress display; a newline is emitted every 100
        # iterations so that a line is kept in the scrollback.
        line_end = '\n' if (iteration + 1) % 100 == 0 else ''
        print('\riter: {:4d}\tloss: {:6.3f}\tavg: {:6.3f}'.format(
            iteration + 1, cost.item(), loss_avg.val()),
              end=line_end)
        sys.stdout.flush()
        if cost < 0.001:
            print(f'iter: {iteration + 1}\tloss: {cost}')

        optimizer.zero_grad()
        if torch.isnan(cost):
            print(f'iter: {iteration + 1}\tloss: {cost}\t==> Loss is NAN')
            writer.close()  # flush pending TensorBoard events before exiting
            sys.exit()
        elif torch.isinf(cost):
            print(f'iter: {iteration + 1}\tloss: {cost}\t==> Loss is INF')
            writer.close()  # flush pending TensorBoard events before exiting
            sys.exit()
        else:
            if opt.fp16:
                with amp.scale_loss(cost, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                cost.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)
        writer.add_scalar('loss_avg',
                          loss_avg.val(),
                          global_step=iteration + 1)

        # validation part
        if iteration == 0 or (
                iteration + 1
        ) % opt.valInterval == 0:  # To see training progress, we also conduct validation when 'iteration == 0'
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                        model, criterion_atten, valid_loader, converter_atten,
                        opt)
                model.train()
                writer.add_scalar('accuracy',
                                  current_accuracy,
                                  global_step=iteration + 1)

                # training loss and validation loss
                loss_log = f'[{iteration + 1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    # cut both strings at the '[s]' end-of-sentence marker
                    gt = gt[:gt.find('[s]')]
                    pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 100000 == 0:
            torch.save(
                model.state_dict(),
                f'./saved_models/{opt.exp_name}/iter_{iteration + 1}.pth')

        # `>=` so that resuming from an iteration beyond num_iter still stops.
        if (iteration + 1) >= opt.num_iter:
            print('end the training')
            writer.close()  # flush pending TensorBoard events before exiting
            sys.exit()

        iteration += 1
Exemple #14
0
def train(opt):
    """Run the training pipeline for the character recognition model.

    Builds the balanced training dataset and the validation loader, picks a
    CTC or attention label converter and loss from ``opt.Prediction``,
    creates the model and optimizer, and then loops over training batches
    with periodic validation and checkpointing. The loop ends through
    ``sys.exit()`` once ``opt.num_iter`` iterations are done.

    Args:
        opt: Options namespace. Attributes read here: ``data_filtering_off``,
            ``select_data``, ``batch_ratio``, ``exp_name``, ``imgH``,
            ``imgW``, ``PAD``, ``valid_data``, ``batch_size``, ``workers``,
            ``Prediction``, ``character``, ``rgb``, ``saved_model``, ``FT``,
            ``adam``, ``lr``, ``beta1``, ``rho``, ``eps``,
            ``batch_max_length``, ``grad_clip``, ``valInterval`` and
            ``num_iter``. Attributes written here: ``select_data``,
            ``batch_ratio``, ``num_class`` and, when ``opt.rgb`` is set,
            ``input_channel``.

    Side effects:
        Writes dataset/option/training logs and ``*.pth`` checkpoints under
        ``./saved_models/{opt.exp_name}/`` and terminates the process with
        ``sys.exit()``.
    """
    if not opt.data_filtering_off:
        print(
            "Filtering the images containing characters which are not in opt.character"
        )
        print(
            "Filtering the images whose label is longer than opt.batch_max_length"
        )

    opt.select_data = opt.select_data.split("-")
    opt.batch_ratio = opt.batch_ratio.split("-")
    train_dataset = Batch_Balanced_Dataset(opt)

    # Log the validation dataset description for this experiment. The
    # context manager guarantees the file is closed even if building the
    # validation dataset raises.
    with open(f"./saved_models/{opt.exp_name}/log_dataset.txt", "a") as log:
        # Collation function for the validation dataloader
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        # Validation dataloader
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=
            True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True,
        )
        log.write(valid_dataset_log)
        print("-" * 80)
        log.write("-" * 80 + "\n")

    # Using either CTC or Attention for char predictions
    if "CTC" in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    # Three input channels when RGB is requested
    if opt.rgb:
        opt.input_channel = 3
    # Defining our model using user inputs
    model = Model(opt)
    print(
        "model input parameters",
        opt.imgH,
        opt.imgW,
        opt.num_fiducial,
        opt.input_channel,
        opt.output_channel,
        opt.hidden_size,
        opt.num_class,
        opt.batch_max_length,
        opt.Transformation,
        opt.FeatureExtraction,
        opt.SequenceModeling,
        opt.Prediction,
    )

    # weight initialization
    for name, param in model.named_parameters():
        if "localization_fc2" in name:
            print(f"Skip {name} as it is already initialized")
            continue
        try:
            if "bias" in name:
                init.constant_(param, 0.0)
            elif "weight" in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # kaiming_normal_ rejects some weights (the original comment
            # attributes this to batchnorm); those are filled with ones.
            if "weight" in name:
                param.data.fill_(1)
            continue
    # Putting model in training mode
    model.train()
    # Optionally start from a previously saved model
    if opt.saved_model != "":
        print(f"loading pretrained model from {opt.saved_model}")
        if opt.FT:
            # fine-tuning: tolerate missing/unexpected keys
            model.load_state_dict(torch.load(opt.saved_model), strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_model))
    print("Model:")
    # Move the model to the module-level `device`
    model.to(device)

    # Setting up loss functions in the case of either CTC or Attention
    if "CTC" in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print("Trainable params num : ", sum(params_num))

    # Setup of optimizer to be used
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # Record every option of this run
    with open(f"./saved_models/{opt.exp_name}/opt.txt", "a") as opt_file:
        opt_log = "------------ Options -------------\n"
        args = vars(opt)
        for k, v in args.items():
            opt_log += f"{str(k)}: {str(v)}\n"
        opt_log += "---------------------------------------\n"
        print(opt_log)
        opt_file.write(opt_log)

    # Training iteration starts here
    start_iter = 0
    if opt.saved_model != "":
        # Checkpoints named like 'iter_<N>.pth' carry the iteration to resume
        # from; any other name leaves the counter at 0.
        try:
            start_iter = int(opt.saved_model.split("_")[-1].split(".")[0])
            print(f"continue to train, start_iter: {start_iter}")
        except ValueError:
            pass

    # Setting up initial metrics results and initializing the timer
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1  # this function treats a higher norm_ED as better
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if "CTC" in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            # move the batch axis to position 1 before calling CTCLoss
            preds = preds.log_softmax(2).permute(1, 0, 2)
            cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if (
                iteration + 1
        ) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'
            elapsed_time = time.time() - start_time
            # for log
            with open(f"./saved_models/{opt.exp_name}/log_train.txt",
                      "a") as log:
                model.eval()
                with torch.no_grad():
                    (
                        valid_loss,
                        current_accuracy,
                        current_norm_ED,
                        preds,
                        confidence_score,
                        labels,
                        infer_time,
                        length_of_data,
                    ) = validation(model, criterion, valid_loader, converter,
                                   opt)
                model.train()

                # training loss and validation loss
                loss_log = f"[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}"
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f"./saved_models/{opt.exp_name}/best_accuracy.pth",
                    )
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f"./saved_models/{opt.exp_name}/best_norm_ED.pth",
                    )
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f"{loss_log}\n{current_model_log}\n{best_model_log}"
                print(loss_model_log)
                log.write(loss_model_log + "\n")

                # show some predicted results
                dashed_line = "-" * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n"
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if "Attn" in opt.Prediction:
                        # cut both strings at the '[s]' end-of-sentence marker
                        gt = gt[:gt.find("[s]")]
                        pred = pred[:pred.find("[s]")]

                    predicted_result_log += f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n"
                predicted_result_log += f"{dashed_line}"
                print(predicted_result_log)
                log.write(predicted_result_log + "\n")

        # save model per 1e+5 iter.
        if (iteration + 1) % 100000 == 0:
            torch.save(
                model.state_dict(),
                f"./saved_models/{opt.exp_name}/iter_{iteration+1}.pth",
            )

        # `>=` so that resuming from an iteration beyond num_iter still stops.
        if (iteration + 1) >= opt.num_iter:
            print("end the training")
            sys.exit()
        iteration += 1