def train(opt):
    """Train the recognition model configured by ``opt`` with periodic validation.

    Builds the training and validation data pipelines, the model, the loss
    and the optimizer, then loops over training batches. Validation runs on
    the first iteration and every ``opt.valInterval`` iterations; the best
    models (by accuracy and by normalized edit distance) and a snapshot every
    100000 iterations are saved under ``./saved_models/{opt.exp_name}/``,
    together with ``log_dataset.txt``, ``opt.txt`` and ``log_train.txt``.

    Args:
        opt: Option namespace. It is mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set from
            the label converter, and ``input_channel`` is set to 3 when
            ``opt.rgb`` is truthy.

    Raises:
        SystemExit: via ``sys.exit()`` once the iteration counter reaches
            ``opt.num_iter``; the function does not return normally.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt, mode='Train')

    # The context manager closes the dataset log even if building the
    # validation pipeline raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt',
              'a') as dataset_log_file:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)

        # NOTE: the log text returned here is not written to the file in
        # this variant; only the dashed separator below is.
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt, mode='Val')
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=False,
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)

        print('-' * 80)
        dataset_log_file.write('-' * 80 + '\n')

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # Initialization failed for this parameter; fall back to filling
            # weights with 1 and leave anything else untouched.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored on whatever `device` this run uses.
        state_dict = torch.load(opt.saved_model, map_location=device)
        # Fine-tuning (opt.FT) tolerates missing/unexpected keys.
        model.load_state_dict(state_dict, strict=not opt.FT)
    print("Model:")
    print(model)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            # need to install warpctc. see our guideline.
            from warpctc_pytorch import CTCLoss
            criterion = CTCLoss()
        else:
            criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        try:
            # Checkpoints named like '..._<iteration>.<ext>' carry the
            # iteration to resume from.
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # The file name does not end in a number; start from 0.
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            if opt.baiduCTC:
                preds = preds.permute(1, 0, 2)  # to use CTCLoss format
                cost = criterion(preds, text, preds_size, length) / batch_size
            else:
                preds = preds.log_softmax(2).permute(1, 0, 2)
                cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        if iteration % 20 == 0:
            print(f'#{iteration}: loss:{cost}', flush=True)

        # validation part
        # To see training progress, we also conduct validation when 'iteration == 0'
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds,
                     confidence_score, labels, infer_time,
                     length_of_data) = validation(model, criterion,
                                                  valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # cut both strings at the '[s]' marker
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 100000 == 0:
            torch.save(
                model.state_dict(),
                f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth')

        # '>=' instead of '==': when resuming from a checkpoint whose
        # iteration already meets or exceeds opt.num_iter, an equality test
        # would never match and the loop would never terminate.
        if (iteration + 1) >= opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
# Example no. 2
# 0
def train(opt):
    """Train the recognition model configured by ``opt`` with optional early stopping.

    Builds the training and validation data pipelines, the model, the loss
    and the optimizer, then loops over training batches. Validation runs
    whenever the iteration counter is a multiple of ``opt.valInterval`` and
    greater than ``opt.ignore_x_vals * opt.valInterval``. The best models (by
    accuracy and by normalized edit distance) and a snapshot every 100000
    iterations are saved under ``./saved_models/{opt.experiment_name}/``,
    together with ``log_dataset.txt``, ``opt.txt`` and ``log_train.txt``.

    Early stopping tracks either validation accuracy or validation loss
    (``opt.early_stopping_param``): each validation without improvement
    consumes one unit of ``opt.early_stopping_patience``; an improvement
    resets it.

    Args:
        opt: Option namespace. It is mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set from
            the label converter, and ``input_channel`` is set to 3 when
            ``opt.rgb`` is truthy.

    Raises:
        Exception: if ``opt.early_stopping_param`` is not ``'accuracy'``,
            ``'loss'`` or ``'None'``.
        SystemExit: via ``sys.exit()`` when the iteration counter reaches
            ``opt.num_iter`` or the early-stopping patience is used up; the
            function does not return normally.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    # The context manager closes the dataset log even if building the
    # validation pipeline raises.
    with open(f'./saved_models/{opt.experiment_name}/log_dataset.txt',
              'a') as dataset_log_file:
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            # 'True' to check training progress with validation function.
            shuffle=True,
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        dataset_log_file.write(valid_dataset_log)
        print('-' * 80)
        dataset_log_file.write('-' * 80 + '\n')

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # Initialization failed for this parameter; fall back to filling
            # weights with 1 and leave anything else untouched.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored on whatever `device` this run uses.
        state_dict = torch.load(opt.saved_model, map_location=device)
        # Fine-tuning (opt.FT) tolerates missing/unexpected keys.
        model.load_state_dict(state_dict, strict=not opt.FT)
    print("Model:")
    print(model)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options ----
    with open(f'./saved_models/{opt.experiment_name}/opt.txt',
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        try:
            # Checkpoints named like '..._<iteration>.<ext>' carry the
            # iteration to resume from.
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            # The file name does not end in a number; start from 0.
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    best_loss = float("inf")
    i = start_iter

    # The option arrives as a string; the literal 'None' disables early
    # stopping.
    if opt.early_stopping_param in ('accuracy', 'loss'):
        early_stopping_param = opt.early_stopping_param
    elif opt.early_stopping_param == 'None':
        early_stopping_param = None
    else:
        raise Exception(
            'early_stopping_param is neither accuracy, loss or None')

    current_patience = opt.early_stopping_patience

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)
        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)

            # (ctc_a) For PyTorch 1.2.0 and 1.3.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            # The previous setting is restored in `finally`, so cuDNN is not
            # left disabled if the loss raises and is not force-enabled when
            # it was off to begin with.
            cudnn_was_enabled = torch.backends.cudnn.enabled
            torch.backends.cudnn.enabled = False
            try:
                cost = criterion(preds, text.to(device),
                                 preds_size.to(device), length.to(device))
            finally:
                torch.backends.cudnn.enabled = cudnn_was_enabled

            # (ctc_b) To reproduce the pretrained model / paper, the authors
            # used `criterion(preds, text, preds_size, length)` instead of
            # (ctc_a); with PyTorch 1.2.0 that produces NaN, see
            # https://github.com/clovaai/deep-text-recognition-benchmark/issues/56#issuecomment-526490707

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i > (opt.ignore_x_vals *
                opt.valInterval) and i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt',
                      'a') as log:
                model.eval()
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, preds,
                     confidence_score, labels, infer_time,
                     length_of_data) = validation(model, criterion,
                                                  valid_loader, converter, opt)
                model.train()

                # training loss and validation loss
                loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                    )
                    if early_stopping_param == 'accuracy':
                        current_patience = opt.early_stopping_patience
                elif early_stopping_param == 'accuracy':
                    # accuracy did not improve
                    current_patience -= 1

                if valid_loss < best_loss:
                    best_loss = valid_loss
                    if early_stopping_param == 'loss':
                        current_patience = opt.early_stopping_patience
                elif early_stopping_param == 'loss':
                    # validation loss did not improve
                    current_patience -= 1

                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                    )
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'

                if early_stopping_param is not None:
                    loss_model_log += f'\n{"Early Stopping Param":17s}: {early_stopping_param}, {"Patience"}: {current_patience}'

                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if 'Attn' in opt.Prediction:
                        # cut both strings at the '[s]' marker
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

                # Only honour the patience counter when early stopping is
                # enabled; otherwise a configured patience of 0 would end
                # training at the first validation.
                if early_stopping_param is not None and current_patience == 0:
                    print(
                        "Stopped the training due to early stopping patience")
                    sys.exit()

        # save model per 1e+5 iter.
        if (i + 1) % 100000 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        # '>=' instead of '==': when resuming from a checkpoint whose
        # iteration already exceeds opt.num_iter, an equality test would
        # never match and the loop would never terminate.
        if i >= opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
# Example no. 3
# 0
def train(opt):
    """Run the training pipeline for the character recognition model.

    Builds the training and validation data pipelines, the model, the loss
    and the optimizer, then loops over training batches. Validation runs on
    the first iteration and every ``opt.valInterval`` iterations; the best
    models (by accuracy and by normalized edit distance) and a snapshot every
    100000 iterations are saved under ``./saved_models/{opt.exp_name}/``,
    together with ``log_dataset.txt``, ``opt.txt`` and ``log_train.txt``.

    Args:
        opt: Option namespace. It is mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``"-"``, ``num_class`` is set from
            the label converter, and ``input_channel`` is set to 3 when
            ``opt.rgb`` is truthy.

    Raises:
        SystemExit: via ``sys.exit()`` once the iteration counter reaches
            ``opt.num_iter``; the function does not return normally.
    """
    if not opt.data_filtering_off:
        print(
            "Filtering the images containing characters which are not in opt.character"
        )
        print(
            "Filtering the images whose label is longer than opt.batch_max_length"
        )

    opt.select_data = opt.select_data.split("-")
    opt.batch_ratio = opt.batch_ratio.split("-")
    train_dataset = Batch_Balanced_Dataset(opt)

    # Logging the experiment, so that we can refer to the performance of
    # previous runs. The context manager closes the file even if building the
    # validation pipeline raises.
    with open(f"./saved_models/{opt.exp_name}/log_dataset.txt",
              "a") as dataset_log_file:
        # Using params from user input to collation function for dataloader
        AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                          imgW=opt.imgW,
                                          keep_ratio_with_pad=opt.PAD)
        # Defining our validation dataloader
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            # 'True' to check training progress with validation function.
            shuffle=True,
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True,
        )
        dataset_log_file.write(valid_dataset_log)
        print("-" * 80)
        dataset_log_file.write("-" * 80 + "\n")

    # Using either CTC or Attention for char predictions
    if "CTC" in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    # Running our OCR model in grayscale or RGB
    if opt.rgb:
        opt.input_channel = 3
    # Defining our model using user inputs
    model = Model(opt)
    print(
        "model input parameters",
        opt.imgH,
        opt.imgW,
        opt.num_fiducial,
        opt.input_channel,
        opt.output_channel,
        opt.hidden_size,
        opt.num_class,
        opt.batch_max_length,
        opt.Transformation,
        opt.FeatureExtraction,
        opt.SequenceModeling,
        opt.Prediction,
    )

    # weight initialization
    for name, param in model.named_parameters():
        if "localization_fc2" in name:
            print(f"Skip {name} as it is already initialized")
            continue
        try:
            if "bias" in name:
                init.constant_(param, 0.0)
            elif "weight" in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # Initialization failed for this parameter; fall back to filling
            # weights with 1 and leave anything else untouched.
            if "weight" in name:
                param.data.fill_(1)
            continue
    # Putting model in training mode
    model.train()
    # Optionally start from a model saved by a previous run
    if opt.saved_model != "":
        print(f"loading pretrained model from {opt.saved_model}")
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be restored on whatever `device` this run uses.
        state_dict = torch.load(opt.saved_model, map_location=device)
        # Fine-tuning (opt.FT) tolerates missing/unexpected keys.
        model.load_state_dict(state_dict, strict=not opt.FT)
    print("Model:")
    # Sending model to the configured device
    model.to(device)

    # Setting up loss functions in the case of either CTC or Attention
    if "CTC" in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print("Trainable params num : ", sum(params_num))

    # Setup of optimizer to be used
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # Dump every option to opt.txt (and stdout) for later reference
    with open(f"./saved_models/{opt.exp_name}/opt.txt", "a") as opt_file:
        opt_log = "------------ Options -------------\n"
        args = vars(opt)
        for k, v in args.items():
            opt_log += f"{str(k)}: {str(v)}\n"
        opt_log += "---------------------------------------\n"
        print(opt_log)
        opt_file.write(opt_log)

    # Training iteration starts here
    start_iter = 0
    if opt.saved_model != "":
        try:
            # Checkpoints named like '..._<iteration>.<ext>' carry the
            # iteration to resume from.
            start_iter = int(opt.saved_model.split("_")[-1].split(".")[0])
            print(f"continue to train, start_iter: {start_iter}")
        except ValueError:
            # The file name does not end in a number; start from 0.
            pass

    # Setting up initial metrics results and initializing the timer
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    while True:
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if "CTC" in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.log_softmax(2).permute(1, 0, 2)
            cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        # To see training progress, we also conduct validation when 'iteration == 0'
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:
            elapsed_time = time.time() - start_time
            # for log
            with open(f"./saved_models/{opt.exp_name}/log_train.txt",
                      "a") as log:
                model.eval()
                with torch.no_grad():
                    (
                        valid_loss,
                        current_accuracy,
                        current_norm_ED,
                        preds,
                        confidence_score,
                        labels,
                        infer_time,
                        length_of_data,
                    ) = validation(model, criterion, valid_loader, converter,
                                   opt)
                model.train()

                # training loss and validation loss
                loss_log = f"[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}"
                loss_avg.reset()

                current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f"./saved_models/{opt.exp_name}/best_accuracy.pth",
                    )
                if current_norm_ED > best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f"./saved_models/{opt.exp_name}/best_norm_ED.pth",
                    )
                best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                loss_model_log = f"{loss_log}\n{current_model_log}\n{best_model_log}"
                print(loss_model_log)
                log.write(loss_model_log + "\n")

                # show some predicted results
                dashed_line = "-" * 80
                head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f"{dashed_line}\n{head}\n{dashed_line}\n"
                for gt, pred, confidence in zip(labels[:5], preds[:5],
                                                confidence_score[:5]):
                    if "Attn" in opt.Prediction:
                        # cut both strings at the '[s]' marker
                        gt = gt[:gt.find("[s]")]
                        pred = pred[:pred.find("[s]")]

                    predicted_result_log += f"{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n"
                predicted_result_log += f"{dashed_line}"
                print(predicted_result_log)
                log.write(predicted_result_log + "\n")

        # save model per 1e+5 iter.
        if (iteration + 1) % 100000 == 0:
            torch.save(
                model.state_dict(),
                f"./saved_models/{opt.exp_name}/iter_{iteration+1}.pth",
            )

        # '>=' instead of '==': when resuming from a checkpoint whose
        # iteration already meets or exceeds opt.num_iter, an equality test
        # would never match and the loop would never terminate.
        if (iteration + 1) >= opt.num_iter:
            print("end the training")
            sys.exit()
        iteration += 1
# Example no. 4
# 0
def benchmark_all_eval(model,
                       criterion,
                       converter,
                       opt,
                       calculate_infer_time=False):
    """Evaluate ``model`` on the 10 benchmark evaluation datasets.

    For every dataset name in the fixed benchmark list, a dataset is built
    from ``os.path.join(opt.eval_data, name)`` and passed to ``validation``.
    Per-dataset accuracy and normalized edit distance are printed and
    appended to ``./result/{opt.exp_name}/log_all_evaluation.txt``, followed
    by one summary line (per-dataset accuracy, total accuracy, averaged
    inference time, number of parameters in millions).

    Args:
        model: network handed to ``validation``; its parameters are counted
            for the summary line.
        criterion: loss object handed to ``validation``.
        converter: label converter handed to ``validation``.
        opt: options object; this function reads ``eval_data``, ``exp_name``,
            ``imgH``, ``imgW``, ``PAD``, ``batch_size`` and ``workers``.
        calculate_infer_time: when true, evaluate with a batch size of 1
            instead of ``opt.batch_size``.

    Returns:
        None.

    Note:
        The ``./result/{opt.exp_name}`` directory is not created here; it
        must already exist.
    """
    # The evaluation datasets, dataset order is same with Table 1 in our paper.
    eval_data_list = [
        'IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 'IC13_1015',
        'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80'
    ]

    # # To easily compute the total accuracy of our paper.
    # eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_867',
    #                   'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80']

    # batch_size should be 1 to calculate the GPU inference time per image.
    evaluation_batch_size = 1 if calculate_infer_time else opt.batch_size

    list_accuracy = []
    total_forward_time = 0
    total_evaluation_data_number = 0
    total_correct_number = 0
    dashed_line = '-' * 80

    # The context manager guarantees the log is flushed and closed even when
    # building a dataset or running validation raises part-way through.
    with open(f'./result/{opt.exp_name}/log_all_evaluation.txt', 'a') as log:
        print(dashed_line)
        log.write(dashed_line + '\n')

        for eval_name in eval_data_list:
            eval_data_path = os.path.join(opt.eval_data, eval_name)
            AlignCollate_evaluation = AlignCollate(imgH=opt.imgH,
                                                   imgW=opt.imgW,
                                                   keep_ratio_with_pad=opt.PAD)
            eval_dataset, eval_data_log = hierarchical_dataset(
                root=eval_data_path, opt=opt)
            evaluation_loader = torch.utils.data.DataLoader(
                eval_dataset,
                batch_size=evaluation_batch_size,
                shuffle=False,
                num_workers=int(opt.workers),
                collate_fn=AlignCollate_evaluation,
                pin_memory=True)

            _, accuracy_by_best_model, norm_ED_by_best_model, _, _, _, infer_time, length_of_data = validation(
                model, criterion, evaluation_loader, converter, opt)
            list_accuracy.append(f'{accuracy_by_best_model:0.3f}')
            total_forward_time += infer_time
            total_evaluation_data_number += len(eval_dataset)
            # Weight each dataset's accuracy by the sample count reported by
            # validation so the total below is a per-sample figure.
            total_correct_number += accuracy_by_best_model * length_of_data

            result_line = (
                f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}'
            )
            log.write(eval_data_log)
            print(result_line)
            log.write(result_line + '\n')
            print(dashed_line)
            log.write(dashed_line + '\n')

        averaged_forward_time = total_forward_time / total_evaluation_data_number * 1000
        total_accuracy = total_correct_number / total_evaluation_data_number
        params_num = sum(np.prod(p.size()) for p in model.parameters())

        per_dataset = ''.join(
            f'{name}: {accuracy}\t'
            for name, accuracy in zip(eval_data_list, list_accuracy))
        evaluation_log = (
            f'accuracy: {per_dataset}'
            f'total_accuracy: {total_accuracy:0.3f}\t'
            f'averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.3f}'
        )
        print(evaluation_log)
        log.write(evaluation_log + '\n')

    return None
Ejemplo n.º 5
0
def detect_ocr(config, image, timestamp, save_img):
    """Detect text regions in ``image`` and recognize the text in each one.

    ``Detection_txt`` supplies the detected regions; each region is batched
    through ``config.model`` and decoded with ``config.converter``. Every
    result is appended to the log file at ``config.logfilepath``, and the
    elapsed recognition time is added to ``config.recog_time``.

    Args:
        config: settings object; reads ``net``, ``device``, ``model``,
            ``converter``, ``imgH``, ``imgW``, ``PAD``, ``batch_size``,
            ``workers``, ``batch_max_length``, ``logfilepath``,
            ``result_folder`` and ``res_imagefileName``, and updates
            ``recog_time``.
        image: input handed to ``Detection_txt``.
        timestamp: opaque value returned unchanged to the caller.
        save_img: when true, results are also printed and ``saveResult`` is
            called with the detection output and predictions.

    Returns:
        ``(pred_list, timestamp)`` where each ``pred_list`` entry is
        ``[coordinate, predicted_text, confidence_score]``.
    """
    detection_list, img, boxes = Detection_txt(config, image, config.net)

    t = time.time()

    device = config.device
    model = config.model
    converter = config.converter

    AlignCollate_demo = AlignCollate(imgH=config.imgH,
                                     imgW=config.imgW,
                                     keep_ratio_with_pad=config.PAD)
    demo_data = RawDataset_wPosition(root=detection_list, opt=config)
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=config.batch_size,
                                              shuffle=False,
                                              num_workers=int(config.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    dashed_line = '-' * 80
    head = f'{"coordinates":25s}\t{"predicted_labels":25s}\tconfidence score'
    pred_list = []

    # predict
    model.eval()
    # Opening the log as a context manager closes it even if inference fails.
    with torch.no_grad(), open(f'{config.logfilepath}', 'a') as log:
        if save_img:
            print(f'{dashed_line}\n{head}\n{dashed_line}')
        log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

        for image_tensors, coordinate_list in demo_loader:
            batch_size = image_tensors.size(0)
            batch_images = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([config.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(
                batch_size, config.batch_max_length + 1).fill_(0).to(device)

            preds = model(batch_images, text_for_pred, is_train=False)

            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)

            for coordinate, pred, pred_max_prob in zip(coordinate_list,
                                                       preds_str,
                                                       preds_max_prob):
                pred_EOS = pred.find('[s]')
                # str.find returns -1 when there is no "[s]" token; slicing
                # with -1 would silently drop the last character and the last
                # probability, so only truncate when the token was found.
                if pred_EOS != -1:
                    # prune after "end of sentence" token ([s])
                    pred = pred[:pred_EOS]
                    pred_max_prob = pred_max_prob[:pred_EOS]

                if pred_EOS == 0:
                    # Empty prediction: there is nothing to multiply.
                    confidence_score = 0.0
                else:
                    # calculate confidence score (= multiply of pred_max_prob)
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1].item()

                coordinate = list(coordinate)
                pred_list.append([coordinate, pred, confidence_score])
                result_line = f'{coordinate}\t{pred:25s}\t{confidence_score:0.4f}'
                if save_img:
                    print(result_line)
                log.write(result_line + '\n')

    recog_time = time.time() - t
    config.recog_time = config.recog_time + recog_time

    if save_img:
        saveResult(img, boxes, pred_list, config.result_folder,
                   config.res_imagefileName)

    return pred_list, timestamp
def test(opt):
    """Load a saved model and evaluate it.

    Builds the label converter and model from ``opt``, loads the weights in
    ``opt.saved_model`` (then ``opt.continue_model`` if it is non-empty),
    copies the checkpoint into ``./result/{opt.experiment_name}/`` and runs
    either ``benchmark_all_eval`` or a single evaluation over
    ``opt.eval_data``. In the single-evaluation case the accuracy is printed
    and appended to ``log_evaluation.txt`` in the result directory.

    Args:
        opt: options object. Besides being read, it is mutated:
            ``num_class`` is set, ``input_channel`` is forced to 3 when
            ``opt.rgb`` is truthy, and ``experiment_name`` is derived from
            the path components of ``opt.saved_model`` after the first one.

    Returns:
        None.
    """
    # Local import keeps this block self-contained.
    import shutil

    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    # NOTE(review): the model is neither wrapped in DataParallel nor moved to
    # `device` here (that line was commented out upstream) - confirm the
    # checkpoints and `validation` are compatible with that.
    # model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))
    opt.experiment_name = '_'.join(opt.saved_model.split('/')[1:])

    if opt.continue_model != '':
        print(f'loading pretrained model from {opt.continue_model}')
        # map_location matches the first load so a checkpoint saved on another
        # device can still be deserialized here.
        model.load_state_dict(
            torch.load(opt.continue_model, map_location=device))

    """ keep evaluation model and result logs """
    result_dir = f'./result/{opt.experiment_name}'
    os.makedirs(result_dir, exist_ok=True)
    # shutil.copy instead of a shell `cp` string: works with paths containing
    # spaces or shell metacharacters and raises on failure instead of
    # silently returning a non-zero status.
    shutil.copy(opt.saved_model, result_dir)

    """ setup loss """
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        # ignore [GO] token = ignore index 0
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    """ evaluation """
    model.eval()
    with torch.no_grad():
        if opt.benchmark_all_eval:  # evaluation with 10 benchmark evaluation datasets
            # NOTE(review): the benchmark_all_eval visible in this file logs
            # under ./result/{opt.exp_name}, not opt.experiment_name - confirm
            # the two agree for your options.
            benchmark_all_eval(model, criterion, converter, opt)
        else:
            # Context manager so the log is closed even if evaluation raises.
            with open(f'{result_dir}/log_evaluation.txt', 'a') as log:
                AlignCollate_evaluation = AlignCollate(
                    imgH=opt.imgH,
                    imgW=opt.imgW,
                    keep_ratio_with_pad=opt.PAD)
                eval_data, eval_data_log = hierarchical_dataset(
                    root=opt.eval_data, opt=opt)
                evaluation_loader = torch.utils.data.DataLoader(
                    eval_data,
                    batch_size=opt.batch_size,
                    shuffle=False,
                    num_workers=int(opt.workers),
                    collate_fn=AlignCollate_evaluation,
                    pin_memory=True)
                _, accuracy_by_best_model, _, _, _, _, _, _ = validation(
                    model, criterion, evaluation_loader, converter, opt)
                log.write(eval_data_log)
                print(f'{accuracy_by_best_model:0.3f}')
                log.write(f'{accuracy_by_best_model:0.3f}\n')
Ejemplo n.º 7
0
def demo(opt):
    """Recognize text in every image under ``opt.image_folder``.

    Builds the converter and model from ``opt``, loads ``opt.saved_model``,
    runs greedy decoding batch by batch and prints one
    ``image_path<TAB>prediction`` line per image.

    Args:
        opt: options object. ``num_class`` is set on it, and
            ``input_channel`` is forced to 3 when ``opt.rgb`` is truthy.

    Returns:
        list[str]: the predicted strings for all images, in loader order.
        Previously the list was reset for every batch, was only filled for
        attention models, and was unbound when the loader yielded nothing.
    """
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    is_attn = 'Attn' in opt.Prediction
    # Accumulated across all batches so the caller receives every prediction.
    predictions = []

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                             1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred).log_softmax(2)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            print('-' * 80)
            print('image_path\tpredicted_labels')
            print('-' * 80)

            for img_name, pred in zip(image_path_list, preds_str):
                if is_attn:
                    eos_index = pred.find('[s]')
                    # find() returns -1 when "[s]" is absent; slicing with -1
                    # would drop the final character, so keep the text as is.
                    if eos_index != -1:
                        # prune after "end of sentence" token ([s])
                        pred = pred[:eos_index]
                predictions.append(pred)
                print(f'{img_name}\t{pred}')
    return predictions
Ejemplo n.º 8
0
def _textRecognition(opt):
    """Recognize one character per image and write the results to a CSV.

    Every image under ``opt.image_folder`` is run through the model. For each
    image only the first character of the decoded string is kept. The outcome
    per image is one of:

    * empty prediction, or confidence below 0.5: ``deleteImageAndText`` is
      called with ``opt.book_img_dirpath`` and the image name, and no row is
      written;
    * first character outside ``[a-zA-Z0-9]``: the row is written with
      ``"Undefined"``;
    * otherwise: the row is written with that character.

    Rows go to ``text_information.csv`` inside ``opt.output_dirpath`` with
    the columns ``bookID`` (image base name) and ``prediction``.

    Args:
        opt: options object. ``num_class`` is set on it, and
            ``input_channel`` is forced to 3 when ``opt.rgb`` is truthy.

    Returns:
        None.
    """
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    # predict
    char_list = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
    is_attn = 'Attn' in opt.Prediction

    csv_filename = os.path.join(opt.output_dirpath, "text_information.csv")
    Header = ["bookID", "prediction"]

    dashed_line = '-' * 80
    head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

    # newline='' is how the csv module expects its files to be opened;
    # without it, platforms that translate "\n" emit blank rows.
    with open(csv_filename, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=Header)
        writer.writeheader()

        model.eval()
        with torch.no_grad():
            for image_tensors, image_path_list in demo_loader:
                batch_size = image_tensors.size(0)
                image = image_tensors.to(device)
                # For max length prediction
                length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                                  batch_size).to(device)
                text_for_pred = torch.LongTensor(
                    batch_size, opt.batch_max_length + 1).fill_(0).to(device)

                if 'CTC' in opt.Prediction:
                    preds = model(image, text_for_pred)

                    # Select max probability (greedy decoding) then decode index to character
                    preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                    _, preds_index = preds.max(2)
                    preds_index = preds_index.view(-1)
                    preds_str = converter.decode(preds_index.data,
                                                 preds_size.data)

                else:
                    preds = model(image, text_for_pred, is_train=False)

                    # select max probability (greedy decoding) then decode index to character
                    _, preds_index = preds.max(2)
                    preds_str = converter.decode(preds_index, length_for_pred)

                print(f'{dashed_line}\n{head}\n{dashed_line}')

                preds_prob = F.softmax(preds, dim=2)
                preds_max_prob, _ = preds_prob.max(dim=2)
                for img_name, pred, pred_max_prob in zip(
                        image_path_list, preds_str, preds_max_prob):
                    if is_attn:
                        pred_EOS = pred.find('[s]')
                        # find() returns -1 when "[s]" is absent; slicing
                        # with -1 would drop the last character/probability.
                        if pred_EOS != -1:
                            # prune after "end of sentence" token ([s])
                            pred = pred[:pred_EOS]
                            pred_max_prob = pred_max_prob[:pred_EOS]

                    # An empty prediction has no first character and no
                    # probabilities to multiply. This explicit check replaces
                    # a bare `except:` and also covers an empty CTC decode,
                    # which used to crash on `pred[0]`.
                    if len(pred_max_prob) == 0 or not pred:
                        deleteImageAndText(opt.book_img_dirpath, img_name)
                        continue

                    # calculate confidence score (= multiply of pred_max_prob)
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                    # Only the first recognized character is used.
                    pred = pred[0]

                    if confidence_score < 0.5:
                        # Too uncertain to trust: drop the image, write no row.
                        deleteImageAndText(opt.book_img_dirpath, img_name)
                        continue

                    if pred not in char_list:
                        pred = "Undefined"

                    # extract the name part of the image
                    filename = os.path.basename(img_name)

                    print(
                        f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')

                    writer.writerow({"bookID": filename, "prediction": pred})
Ejemplo n.º 9
0
def demo(opt):
    """Recognize text per image and write ``x,y,...,text`` rows to a CSV.

    Bounding boxes come from ``parse_csv(opt.input_image, opt.boxescsv)``.
    Every image under ``opt.image_folder`` is recognized and one row per
    image is written to ``{opt.output_folder}result.csv``: the points of the
    matching bounding box followed by the predicted text. Afterwards the CSV
    is copied to ``/input/output`` and that directory is zipped into
    ``per_word_visual.zip``.

    The file is opened once for the whole run and boxes are looked up by the
    running image index. Previously the file was reopened in ``'w'`` mode for
    every batch (keeping only the last batch) and boxes were looked up by the
    index inside the batch.

    Args:
        opt: options object. ``num_class`` is set on it, and
            ``input_channel`` is forced to 3 when ``opt.rgb`` is truthy.

    Returns:
        None.
    """
    inputimage = opt.input_image
    boxesscv = opt.boxescsv
    bboxes = parse_csv(inputimage, boxesscv)
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    is_attn = 'Attn' in opt.Prediction
    result_path = f'{opt.output_folder}result.csv'
    # Position of the current image across ALL batches.
    # Assumes bboxes is ordered like the images yielded by the loader -
    # TODO confirm against parse_csv / RawDataset.
    img_index = 0

    # predict
    model.eval()
    with torch.no_grad(), open(result_path, 'w') as log:
        for image_tensors, _ in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                             1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, preds_size)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            # The confidence score was computed but never used (and crashed
            # on empty predictions), so it is no longer calculated.
            for pred in preds_str:
                if is_attn:
                    eos_index = pred.find('[s]')
                    # find() returns -1 when "[s]" is absent; slicing with -1
                    # would drop the final character.
                    if eos_index != -1:
                        # prune after "end of sentence" token ([s])
                        pred = pred[:eos_index]

                for pts in bboxes[img_index]:
                    x, y = pts
                    log.write(f'{x},{y},')
                log.write(f'{pred}\n')
                img_index += 1

    # copy log to local output folder (after the file is closed and flushed)
    shutil.copy(result_path, '/input/output')
    shutil.make_archive('per_word_visual', 'zip', '/input/output')
Ejemplo n.º 10
0
def demo(opt):
    """Recognize images and report accuracy against filename-encoded labels.

    Every image under ``opt.image_folder`` is recognized. The expected text
    of an image is the part of its file name before the first ``_``.
    Per-image results are printed and appended to ``./log_demo_result.txt``;
    a summary (sample count, accuracy, average confidence of correct and of
    wrong samples) is printed at the end.

    Behavior notes:
        * An image whose prediction is empty is reported with ``!!!!!`` and
          counted as wrong with confidence 0. The previous bare ``except``
          added a stale (or unbound) confidence value instead.
        * The two average confidences are divided by the number of correct /
          wrong samples respectively, not by the total sample count.
        * With no samples at all the summary reports zeros instead of raising
          ``ZeroDivisionError``.

    Args:
        opt: options object. ``num_class`` is set on it, and
            ``input_channel`` is forced to 3 when ``opt.rgb`` is truthy.

    Returns:
        None.
    """
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    is_attn = 'Attn' in opt.Prediction
    dashed_line = '-' * 80
    head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

    correct_count = 0
    sample_num = 0
    conf_sum = [0.0, 0.0]  # summed confidence of: correct, wrong

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)

            # Context manager so the log is closed even if a sample fails.
            with open('./log_demo_result.txt', 'a') as log:
                print(f'{dashed_line}\n{head}\n{dashed_line}')
                log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

                for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                    if is_attn:
                        pred_EOS = pred.find('[s]')
                        # find() returns -1 when "[s]" is absent; slicing
                        # with -1 would drop the last character/probability.
                        if pred_EOS != -1:
                            # prune after "end of sentence" token ([s])
                            pred = pred[:pred_EOS]
                            pred_max_prob = pred_max_prob[:pred_EOS]

                    if len(pred_max_prob) == 0:
                        # Nothing was predicted: count the sample as wrong
                        # with zero confidence and flag it for inspection.
                        print("!!!!!", img_name)
                        continue

                    # calculate confidence score (= multiply of pred_max_prob)
                    # float() keeps the running sums as plain Python numbers.
                    confidence_score = float(pred_max_prob.cumprod(dim=0)[-1])

                    print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                    log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')

                    # Ground truth is the file-name prefix before the first "_".
                    expected = os.path.split(img_name)[-1].split("_")[0]
                    if expected == pred:
                        correct_count += 1
                        conf_sum[0] += confidence_score
                    else:
                        conf_sum[1] += confidence_score

            sample_num += len(image_path_list)

    wrong_count = sample_num - correct_count
    acc = correct_count / sample_num if sample_num else 0.0
    avgconf = [
        conf_sum[0] / correct_count if correct_count else 0.0,
        conf_sum[1] / wrong_count if wrong_count else 0.0,
    ]
    print(f"Sample num:{sample_num}\nAccuracy:{acc}\nCorrect samples Average Confidence:{avgconf[0]}\nWrong samples Average Confidence:{avgconf[1]}")
Ejemplo n.º 11
0
def _latest_checkpoint(pattern):
    """Return the path matching *pattern* with the highest iteration, or None.

    Checkpoints written by ``train`` are named ``iter_<N>_<kind>.pth``.
    ``glob.glob`` returns matches in arbitrary order, so the newest file is
    selected explicitly by its iteration number instead of by list position.
    """

    def _iteration_of(path):
        try:
            return int(os.path.basename(path).split('_')[1])
        except (IndexError, ValueError):
            # Unexpected file name: rank it below every well-formed one.
            return -1

    candidates = glob.glob(pattern)
    if not candidates:
        return None
    return max(candidates, key=_iteration_of)


def _load_pretrained_weights(network, path, description, fine_tune):
    """Load a state dict from *path* into *network* unless the path is '' or 'None'.

    When *fine_tune* is truthy the load is non-strict, so missing or
    unexpected keys are tolerated.
    """
    if path != '' and path != 'None':
        print(f'loading pretrained {description} model from {path}')
        network.load_state_dict(torch.load(path), strict=not fine_tune)


def _trainable_parameters(network):
    """Return the list of parameters of *network* that require gradients."""
    return [p for p in network.parameters() if p.requires_grad]


def _build_optimizer(parameters, opt):
    """Create an Adam optimizer when ``opt.optim == 'adam'``, else Adadelta."""
    if opt.optim == 'adam':
        return optim.Adam(parameters,
                          lr=opt.lr,
                          betas=(opt.beta1, opt.beta2),
                          weight_decay=opt.weight_decay)
    return optim.Adadelta(parameters,
                          lr=opt.lr,
                          rho=opt.rho,
                          eps=opt.eps,
                          weight_decay=opt.weight_decay)


def train(opt):
    """Jointly train the synthesis model, its discriminator and (optionally) the OCR model.

    Three networks are built from ``opt``: the generator (``AdaINGenV4``), the
    text recogniser (``Model``) and the image discriminator (``MsImageDisV1``).
    Each iteration:

    1. optionally updates the OCR model on the input images
       (skipped when ``opt.ocrFixed`` is set; the OCR predictions are then
       used as the labels of the input images instead of the ground truth),
    2. updates the discriminator on detached generator outputs,
    3. updates the generator with the weighted sum of OCR content loss,
       L1 image reconstruction loss, adversarial loss and style
       reconstruction loss.

    Every ``opt.valInterval`` iterations (and at iteration 0) sample images
    are written, validation is run, losses/accuracies are plotted and logged,
    and the best checkpoints are saved. Periodic checkpoints are written
    whenever ``iteration % 1e5 == 0``.

    Args:
        opt: option namespace. It is mutated in place: ``select_data`` and
            ``batch_ratio`` are split on '-', ``batch_size`` is doubled,
            ``num_class`` is set, ``input_channel`` is set to 3 for RGB, and
            ``saved_synth_model`` / ``saved_dis_model`` may be replaced by the
            newest checkpoint in the experiment folder when
            ``opt.modelFolderFlag`` is set.

    Returns:
        Never returns normally; the process ends through ``sys.exit()`` once
        ``opt.num_iter`` iterations have been reached.
    """
    plotDir = os.path.join(opt.exp_dir, opt.exp_name, 'plots')
    os.makedirs(plotDir, exist_ok=True)

    # Must run before any other local is created so the printed settings
    # contain only `opt` and `plotDir`.
    lib.print_model_settings(locals().copy())

    exp_root = os.path.join(opt.exp_dir, opt.exp_name)

    # ------------------------------------------------------------------
    # dataset preparation
    # ------------------------------------------------------------------
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')

    # Every batch is later split in two halves: the first half feeds the
    # generator, the second half serves as real images for the discriminator.
    opt.batch_size = opt.batch_size * 2

    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=
        False,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)

    # Context manager guarantees the handle is closed even if a write fails.
    with open(os.path.join(exp_root, 'log_dataset.txt'), 'a') as dataset_log:
        dataset_log.write(valid_dataset_log)
        print('-' * 80)
        dataset_log.write('-' * 80 + '\n')

    # ------------------------------------------------------------------
    # model configuration
    # ------------------------------------------------------------------
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    model = AdaINGenV4(opt)
    ocrModel = Model(opt)
    disModel = MsImageDisV1(opt)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for currModel in [model, ocrModel, disModel]:
        for name, param in currModel.named_parameters():
            if 'localization_fc2' in name:
                print(f'Skip {name} as it is already initialized')
                continue
            try:
                if 'bias' in name:
                    init.constant_(param, 0.0)
                elif 'weight' in name:
                    init.kaiming_normal_(param)
            except Exception:  # for batchnorm.
                # kaiming_normal_ rejects tensors with fewer than 2 dims
                # (e.g. batchnorm weights); fall back to ones.
                if 'weight' in name:
                    param.data.fill_(1)

    # data parallel for multi-GPU
    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    if not opt.ocrFixed:
        ocrModel.train()
    else:
        # SequenceModeling is deliberately left in its current (train) mode.
        ocrModel.module.Transformation.eval()
        ocrModel.module.FeatureExtraction.eval()
        ocrModel.module.AdaptiveAvgPool.eval()
        ocrModel.module.Prediction.eval()

    model = torch.nn.DataParallel(model).to(device)
    model.train()

    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    # Resume from the newest checkpoints found in the experiment folder.
    if opt.modelFolderFlag:
        latest_synth = _latest_checkpoint(
            os.path.join(exp_root, "iter_*_synth.pth"))
        if latest_synth is not None:
            opt.saved_synth_model = latest_synth

        latest_dis = _latest_checkpoint(
            os.path.join(exp_root, "iter_*_dis.pth"))
        if latest_dis is not None:
            opt.saved_dis_model = latest_dis

    # loading pre-trained models
    _load_pretrained_weights(ocrModel, opt.saved_ocr_model, 'ocr', opt.FT)
    print("OCRModel:")
    print(ocrModel)

    _load_pretrained_weights(model, opt.saved_synth_model, 'synth', opt.FT)
    print("SynthModel:")
    print(model)

    _load_pretrained_weights(disModel, opt.saved_dis_model, 'discriminator',
                             opt.FT)
    print("DisModel:")
    print(disModel)

    # ------------------------------------------------------------------
    # setup loss
    # ------------------------------------------------------------------
    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0

    recCriterion = torch.nn.L1Loss()
    styleRecCriterion = torch.nn.L1Loss()

    # loss averagers (reset after every validation round)
    loss_avg_ocr = Averager()
    loss_avg = Averager()
    loss_avg_dis = Averager()

    loss_avg_ocrRecon_1 = Averager()
    loss_avg_ocrRecon_2 = Averager()
    loss_avg_gen = Averager()
    loss_avg_imgRecon = Averager()
    loss_avg_styRecon = Averager()

    # ------------------------------------------------------------------
    # optimizers: one per network, over parameters that require gradients
    # ------------------------------------------------------------------
    filtered_parameters = _trainable_parameters(model)
    print('Trainable params num : ',
          sum([np.prod(p.size()) for p in filtered_parameters]))
    optimizer = _build_optimizer(filtered_parameters, opt)
    print("SynthOptimizer:")
    print(optimizer)

    ocr_filtered_parameters = _trainable_parameters(ocrModel)
    print('OCR Trainable params num : ',
          sum([np.prod(p.size()) for p in ocr_filtered_parameters]))
    ocr_optimizer = _build_optimizer(ocr_filtered_parameters, opt)
    print("OCROptimizer:")
    print(ocr_optimizer)

    dis_filtered_parameters = _trainable_parameters(disModel)
    print('Dis Trainable params num : ',
          sum([np.prod(p.size()) for p in dis_filtered_parameters]))
    dis_optimizer = _build_optimizer(dis_filtered_parameters, opt)
    print("DisOptimizer:")
    print(dis_optimizer)

    # ------------------------------------------------------------------
    # final options
    # ------------------------------------------------------------------
    with open(os.path.join(exp_root, 'opt.txt'), 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ------------------------------------------------------------------
    # start training
    # ------------------------------------------------------------------
    start_iter = 0

    if opt.saved_synth_model != '' and opt.saved_synth_model != 'None':
        # Checkpoints named 'iter_<N>_synth.pth' carry the iteration count;
        # any other file name simply restarts counting from 0.
        try:
            start_iter = int(
                opt.saved_synth_model.split('_')[-2].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except (IndexError, ValueError):
            pass

    # get schedulers
    scheduler = get_scheduler(optimizer, opt)
    ocr_scheduler = get_scheduler(ocr_optimizer, opt)
    dis_scheduler = get_scheduler(dis_optimizer, opt)

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    best_accuracy_ocr = -1
    best_norm_ED_ocr = -1
    iteration = start_iter

    while True:
        # ---------------------------- train part ----------------------------

        # NOTE(review): schedulers are stepped before the optimizers; recent
        # PyTorch versions expect the opposite order. Left unchanged to keep
        # the existing learning-rate schedule.
        if opt.lr_policy != "None":
            scheduler.step()
            ocr_scheduler.step()
            dis_scheduler.step()

        image_tensors_all, labels_1_all, labels_2_all = train_dataset.get_batch(
        )

        # First half: generator inputs with their labels; second half: real
        # images shown to the discriminator.
        disCnt = int(image_tensors_all.size(0) / 2)
        image_tensors = image_tensors_all[:disCnt]
        image_tensors_real = image_tensors_all[disCnt:disCnt + disCnt]
        labels_gt = labels_1_all[:disCnt]
        labels_2 = labels_2_all[:disCnt]

        image = image_tensors.to(device)
        image_real = image_tensors_real.to(device)
        batch_size = image.size(0)

        ##-----------------------------------##
        # generate text(labels) from ocr.forward
        if opt.ocrFixed:
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                             1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                # The full CTC output sequence is decoded. A previous slice
                # here referenced the undefined name `text_for_loss` and
                # raised NameError.
                preds = ocrModel(image, text_for_pred)
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                labels_1 = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = ocrModel(image, text_for_pred, is_train=False)
                _, preds_index = preds.max(2)
                labels_1 = converter.decode(preds_index, length_for_pred)
                for idx, pred in enumerate(labels_1):
                    pred_EOS = pred.find('[s]')
                    # prune after "end of sentence" token ([s])
                    labels_1[idx] = pred[:pred_EOS]
        else:
            labels_1 = labels_gt

        ##-----------------------------------##

        text_1, length_1 = converter.encode(
            labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(
            labels_2, batch_max_length=opt.batch_max_length)

        # forward pass from style and word generator
        images_recon_1, images_recon_2, style = model(image, text_1, text_2)

        if 'CTC' in opt.Prediction:

            if not opt.ocrFixed:
                # ocr training with orig image
                preds_ocr = ocrModel(image, text_1)
                preds_size_ocr = torch.IntTensor([preds_ocr.size(1)] *
                                                 batch_size)
                preds_ocr = preds_ocr.log_softmax(2).permute(1, 0, 2)

                ocrCost_train = ocrCriterion(preds_ocr, text_1, preds_size_ocr,
                                             length_1)

            # content loss for reconstructed images
            preds_1 = ocrModel(images_recon_1, text_1)
            preds_size_1 = torch.IntTensor([preds_1.size(1)] * batch_size)
            preds_1 = preds_1.log_softmax(2).permute(1, 0, 2)

            preds_2 = ocrModel(images_recon_2, text_2)
            preds_size_2 = torch.IntTensor([preds_2.size(1)] * batch_size)
            preds_2 = preds_2.log_softmax(2).permute(1, 0, 2)
            ocrCost_1 = ocrCriterion(preds_1, text_1, preds_size_1, length_1)
            ocrCost_2 = ocrCriterion(preds_2, text_2, preds_size_2, length_2)

        else:
            if not opt.ocrFixed:
                # ocr training with orig image
                preds_ocr = ocrModel(
                    image, text_1[:, :-1])  # align with Attention.forward
                target_ocr = text_1[:, 1:]  # without [GO] Symbol

                ocrCost_train = ocrCriterion(
                    preds_ocr.view(-1, preds_ocr.shape[-1]),
                    target_ocr.contiguous().view(-1))

            # content loss for reconstructed images
            preds_1 = ocrModel(images_recon_1, text_1[:, :-1],
                               is_train=False)  # align with Attention.forward
            target_1 = text_1[:, 1:]  # without [GO] Symbol

            preds_2 = ocrModel(images_recon_2, text_2[:, :-1],
                               is_train=False)  # align with Attention.forward
            target_2 = text_2[:, 1:]  # without [GO] Symbol

            ocrCost_1 = ocrCriterion(preds_1.view(-1, preds_1.shape[-1]),
                                     target_1.contiguous().view(-1))
            ocrCost_2 = ocrCriterion(preds_2.view(-1, preds_2.shape[-1]),
                                     target_2.contiguous().view(-1))

        if not opt.ocrFixed:
            # training OCR
            ocrModel.zero_grad()
            ocrCost_train.backward()
            ocr_optimizer.step()
            loss_avg_ocr.add(ocrCost_train)
        else:
            # if ocr is fixed there is no OCR training loss; log zero
            loss_avg_ocr.add(torch.tensor(0.0))

        # Domain discriminator: Dis update (generator outputs are detached so
        # only the discriminator receives gradients here)
        disModel.zero_grad()
        disCost = opt.disWeight * 0.5 * (
            disModel.module.calc_dis_loss(images_recon_1.detach(), image_real)
            + disModel.module.calc_dis_loss(images_recon_2.detach(), image))
        disCost.backward()
        dis_optimizer.step()
        loss_avg_dis.add(disCost)

        # [Style Encoder] + [Word Generator] update
        # Adversarial loss
        disGenCost = 0.5 * (disModel.module.calc_gen_loss(images_recon_1) +
                            disModel.module.calc_gen_loss(images_recon_2))

        # Input reconstruction loss
        recCost = recCriterion(images_recon_1, image)

        # Pair style reconstruction loss
        if opt.styleReconWeight == 0.0:
            styleRecCost = torch.tensor(0.0)
        else:
            if opt.styleDetach:
                styleRecCost = styleRecCriterion(
                    model(images_recon_2, None, None, styleFlag=True),
                    style.detach())
            else:
                styleRecCost = styleRecCriterion(
                    model(images_recon_2, None, None, styleFlag=True), style)

        # OCR Content cost
        ocrCost = 0.5 * (ocrCost_1 + ocrCost_2)

        cost = opt.ocrWeight * ocrCost + opt.reconWeight * recCost + opt.disWeight * disGenCost + opt.styleReconWeight * styleRecCost

        # Gradients flow through the OCR model and discriminator as well, but
        # only the generator's optimizer is stepped.
        model.zero_grad()
        ocrModel.zero_grad()
        disModel.zero_grad()
        cost.backward()
        optimizer.step()
        loss_avg.add(cost)

        # Individual losses
        loss_avg_ocrRecon_1.add(opt.ocrWeight * 0.5 * ocrCost_1)
        loss_avg_ocrRecon_2.add(opt.ocrWeight * 0.5 * ocrCost_2)
        loss_avg_gen.add(opt.disWeight * disGenCost)
        loss_avg_imgRecon.add(opt.reconWeight * recCost)
        loss_avg_styRecon.add(opt.styleReconWeight * styleRecCost)

        # -------------------------- validation part --------------------------
        # To see training progress, we also conduct validation when 'iteration == 0'
        if (iteration + 1) % opt.valInterval == 0 or iteration == 0:

            # Save training images
            train_image_dir = os.path.join(exp_root, 'trainImages',
                                           str(iteration))
            os.makedirs(train_image_dir, exist_ok=True)
            for trImgCntr in range(batch_size):
                # Best effort: labels may produce invalid file names.
                try:
                    save_image(
                        tensor2im(image[trImgCntr].detach()),
                        os.path.join(
                            train_image_dir,
                            str(trImgCntr) + '_input_' + labels_gt[trImgCntr] +
                            '.png'))
                    save_image(
                        tensor2im(images_recon_1[trImgCntr].detach()),
                        os.path.join(
                            train_image_dir,
                            str(trImgCntr) + '_recon_' + labels_1[trImgCntr] +
                            '.png'))
                    save_image(
                        tensor2im(images_recon_2[trImgCntr].detach()),
                        os.path.join(
                            train_image_dir,
                            str(trImgCntr) + '_pair_' + labels_2[trImgCntr] +
                            '.png'))
                except Exception:
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time

            with open(os.path.join(exp_root, 'log_train.txt'), 'a') as log:
                model.eval()
                ocrModel.module.Transformation.eval()
                ocrModel.module.FeatureExtraction.eval()
                ocrModel.module.AdaptiveAvgPool.eval()
                ocrModel.module.SequenceModeling.eval()
                ocrModel.module.Prediction.eval()
                disModel.eval()

                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation_synth_lrw_res(
                        iteration, model, ocrModel, disModel, recCriterion,
                        styleRecCriterion, ocrCriterion, valid_loader,
                        converter, opt)

                # Restore training modes. With a fixed OCR model only
                # SequenceModeling goes back to train mode, mirroring the
                # setup performed before the loop.
                model.train()
                if not opt.ocrFixed:
                    ocrModel.train()
                else:
                    ocrModel.module.SequenceModeling.train()

                disModel.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train OCR loss: {loss_avg_ocr.val():0.5f}, Train Synth loss: {loss_avg.val():0.5f}, Train Dis loss: {loss_avg_dis.val():0.5f}, Valid OCR loss: {valid_loss[0]:0.5f}, Valid Synth loss: {valid_loss[1]:0.5f}, Valid Dis loss: {valid_loss[2]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'

                current_model_log_ocr = f'{"Current_accuracy_OCR":17s}: {current_accuracy[0]:0.3f}, {"Current_norm_ED_OCR":17s}: {current_norm_ED[0]:0.2f}'
                current_model_log_1 = f'{"Current_accuracy_recon":17s}: {current_accuracy[1]:0.3f}, {"Current_norm_ED_recon":17s}: {current_norm_ED[1]:0.2f}'
                current_model_log_2 = f'{"Current_accuracy_pair":17s}: {current_accuracy[2]:0.3f}, {"Current_norm_ED_pair":17s}: {current_norm_ED[2]:0.2f}'

                # plotting (series name -> scalar value, in display order)
                plot_series = [
                    ('Train-OCR-Loss', loss_avg_ocr.val().item()),
                    ('Train-Synth-Loss', loss_avg.val().item()),
                    ('Train-Dis-Loss', loss_avg_dis.val().item()),
                    ('Train-OCR-Recon1-Loss',
                     loss_avg_ocrRecon_1.val().item()),
                    ('Train-OCR-Recon2-Loss',
                     loss_avg_ocrRecon_2.val().item()),
                    ('Train-Gen-Loss', loss_avg_gen.val().item()),
                    ('Train-ImgRecon1-Loss', loss_avg_imgRecon.val().item()),
                    ('Train-StyRecon2-Loss', loss_avg_styRecon.val().item()),
                    ('Valid-OCR-Loss', valid_loss[0].item()),
                    ('Valid-Synth-Loss', valid_loss[1].item()),
                    ('Valid-Dis-Loss', valid_loss[2].item()),
                    ('Valid-OCR-Recon1-Loss', valid_loss[3].item()),
                    ('Valid-OCR-Recon2-Loss', valid_loss[4].item()),
                    ('Valid-Gen-Loss', valid_loss[5].item()),
                    ('Valid-ImgRecon1-Loss', valid_loss[6].item()),
                    ('Valid-StyRecon2-Loss', valid_loss[7].item()),
                    ('Orig-OCR-WordAccuracy', current_accuracy[0]),
                    ('Recon-OCR-WordAccuracy', current_accuracy[1]),
                    ('Pair-OCR-WordAccuracy', current_accuracy[2]),
                    ('Orig-OCR-CharAccuracy', current_norm_ED[0]),
                    ('Recon-OCR-CharAccuracy', current_norm_ED[1]),
                    ('Pair-OCR-CharAccuracy', current_norm_ED[2]),
                ]
                for series_name, series_value in plot_series:
                    lib.plot.plot(os.path.join(plotDir, series_name),
                                  series_value)

                # keep best accuracy model (on valid dataset), judged on the
                # reconstructed images
                if current_accuracy[1] > best_accuracy:
                    best_accuracy = current_accuracy[1]
                    torch.save(model.state_dict(),
                               os.path.join(exp_root, 'best_accuracy.pth'))
                    torch.save(
                        disModel.state_dict(),
                        os.path.join(exp_root, 'best_accuracy_dis.pth'))
                if current_norm_ED[1] > best_norm_ED:
                    best_norm_ED = current_norm_ED[1]
                    torch.save(model.state_dict(),
                               os.path.join(exp_root, 'best_norm_ED.pth'))
                    torch.save(disModel.state_dict(),
                               os.path.join(exp_root, 'best_norm_ED_dis.pth'))
                best_model_log = f'{"Best_accuracy_Recon":17s}: {best_accuracy:0.3f}, {"Best_norm_ED_Recon":17s}: {best_norm_ED:0.2f}'

                # keep best OCR model (on valid dataset); a fixed OCR model
                # is tracked but never written
                if current_accuracy[0] > best_accuracy_ocr:
                    best_accuracy_ocr = current_accuracy[0]
                    if not opt.ocrFixed:
                        torch.save(
                            ocrModel.state_dict(),
                            os.path.join(exp_root, 'best_accuracy_ocr.pth'))
                if current_norm_ED[0] > best_norm_ED_ocr:
                    best_norm_ED_ocr = current_norm_ED[0]
                    if not opt.ocrFixed:
                        torch.save(
                            ocrModel.state_dict(),
                            os.path.join(exp_root, 'best_norm_ED_ocr.pth'))
                best_model_log_ocr = f'{"Best_accuracy_ocr":17s}: {best_accuracy_ocr:0.3f}, {"Best_norm_ED_ocr":17s}: {best_norm_ED_ocr:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log_ocr}\n{current_model_log_1}\n{current_model_log_2}\n{best_model_log_ocr}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":32s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                for gt_ocr, pred_ocr, confidence_ocr, gt_1, pred_1, confidence_1, gt_2, pred_2, confidence_2 in zip(
                        labels[0][:5], preds[0][:5], confidence_score[0][:5],
                        labels[1][:5], preds[1][:5], confidence_score[1][:5],
                        labels[2][:5], preds[2][:5], confidence_score[2][:5]):
                    if 'Attn' in opt.Prediction:
                        # prune predictions after the "[s]" end token
                        pred_ocr = pred_ocr[:pred_ocr.find('[s]')]
                        pred_1 = pred_1[:pred_1.find('[s]')]
                        pred_2 = pred_2[:pred_2.find('[s]')]

                    predicted_result_log += f'{"ocr"}: {gt_ocr:27s} | {pred_ocr:25s} | {confidence_ocr:0.4f}\t{str(pred_ocr == gt_ocr)}\n'
                    predicted_result_log += f'{"recon"}: {gt_1:25s} | {pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n'
                    predicted_result_log += f'{"pair"}: {gt_2:26s} | {pred_2:25s} | {confidence_2:0.4f}\t{str(pred_2 == gt_2)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

                loss_avg_ocr.reset()
                loss_avg.reset()
                loss_avg_dis.reset()

                loss_avg_ocrRecon_1.reset()
                loss_avg_ocrRecon_2.reset()
                loss_avg_gen.reset()
                loss_avg_imgRecon.reset()
                loss_avg_styRecon.reset()

            lib.plot.flush()

        lib.plot.tick()

        # save model per 1e+5 iter.
        if (iteration) % 1e+5 == 0:
            checkpoint_prefix = 'iter_' + str(iteration + 1)
            torch.save(
                model.state_dict(),
                os.path.join(exp_root, checkpoint_prefix + '_synth.pth'))
            if not opt.ocrFixed:
                torch.save(
                    ocrModel.state_dict(),
                    os.path.join(exp_root, checkpoint_prefix + '_ocr.pth'))
            torch.save(disModel.state_dict(),
                       os.path.join(exp_root, checkpoint_prefix + '_dis.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
# Example no. 12
# 0
def demo(opt):
    """Recognise text in ``opt.image_folder`` and dump annotated result images.

    The function builds a label converter and a ``Model`` from ``opt``,
    optionally loads ``opt.saved_model``, reads expected labels from
    ``gt.txt`` in the working directory and, for every image, writes a copy
    with the prediction drawn on it to ``./trash/true/`` (prediction equals
    the ground truth) or ``./trash/false/`` (prediction differs; the ground
    truth is drawn below the prediction).

    NOTE(review): everything from the ``min_point = points_sorted[i][0]``
    statement onwards references names that are never defined in this
    function (``points_sorted``, ``i``, ``order_sorted``, ``result_folder``,
    ``filename``, ``rgb_img``) and the final loop has no effect beyond moving
    a batch to ``device``.  It looks like the tail of a different function was
    spliced in here -- confirm against the original sources before relying on
    this code.

    Args:
        opt: options namespace.  Attributes read here: ``SequenceModeling``,
            ``Prediction``, ``character``, ``rgb``, ``imgH``, ``imgW``,
            ``num_fiducial``, ``input_channel``, ``output_channel``,
            ``hidden_size``, ``batch_max_length``, ``Transformation``,
            ``FeatureExtraction``, ``saved_model``, ``PAD``, ``image_folder``,
            ``batch_size`` and ``workers``.  ``num_class`` (and
            ``input_channel`` when ``opt.rgb`` is set) are written.
    """
    # model configuration: pick the label converter matching the architecture
    if 'Transformer' in opt.SequenceModeling:
        converter = TransformerLabelConverter(opt.character)
    elif 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # load model
    if opt.saved_model != '':
        print('loading pretrained model from %s' % opt.saved_model)
        checkpoint = torch.load(opt.saved_model)
        # A dict checkpoint is expected to carry the weights under the
        # 'state_dict' key; anything else is passed to load_state_dict as-is.
        if type(checkpoint) == dict:
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint)
        # Drop the checkpoint reference before moving the model to the GPU.
        del checkpoint
        torch.cuda.empty_cache()

    # The weights are loaded *before* the DataParallel wrap, so the checkpoint
    # keys are matched against the bare model's parameter names.
    model = torch.nn.DataParallel(model)
    if torch.cuda.is_available():
        model = model.cuda()

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    # predict
    model.eval()
    # Ground truth: each line of gt.txt is split on ', "' into a key (looked up
    # later by image file name) and a label, which is lower-cased after the
    # trailing '"\n' is removed.
    dict_gt = {}
    with open('gt.txt', 'r') as gt_file:
        gt = gt_file.readlines()
        for line in gt:
            key = line.split(', "')[0]
            value = line.split(', "')[1].replace('"\n', '').lower()
            dict_gt[key] = value
    for image_tensors, image_path_list in demo_loader:
        batch_size = image_tensors.size(0)
        # NOTE(review): only the tensor setup is inside no_grad(); the forward
        # passes below run with autograd enabled.  The .cuda() / torch.cuda.*
        # calls are also unconditional although the model is only moved to the
        # GPU when one is available.
        with torch.no_grad():
            image = image_tensors.cuda()
            # For max length prediction
            length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] *
                                                   batch_size)
            text_for_pred = torch.cuda.LongTensor(
                batch_size, opt.batch_max_length + 1).fill_(0)
        if 'Transformer' in opt.SequenceModeling:
            preds = model(image, text_for_pred, is_train=False)
            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)

        elif 'CTC' in opt.Prediction:
            preds = model(image, text_for_pred).log_softmax(2)

            # Select max probability (greedy decoding) then decode index to character
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            _, preds_index = preds.permute(1, 0, 2).max(2)
            preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
            preds_str = converter.decode(preds_index.data, preds_size.data)

        else:
            preds = model(image, text_for_pred, is_train=False)

            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)

        print('-' * 80)
        print('image_path\tpredicted_labels')
        print('-' * 80)
        for img_name, pred in zip(image_path_list, preds_str):
            # NOTE(review): str.find returns -1 when the end token is absent,
            # in which case these slices drop the last character of pred.
            if 'Transformer' in opt.SequenceModeling:
                pred = pred[:pred.find('</s>')]
            elif 'Attn' in opt.Prediction:
                # prune after "end of sentence" token ([s])
                pred = pred[:pred.find('[s]')]
            # Build a 200x128 white canvas with the image resized to 200x64 in
            # the top half, leaving room for text underneath.
            raw_img = cv2.imread(img_name)
            raw_img = cv2.resize(raw_img, (200, 64))
            tmp_img = np.zeros((128, 200, 3), np.uint8)
            tmp_img.fill(255)
            tmp_img[:64, :200] = raw_img
            raw_img = tmp_img
            font = cv2.FONT_HERSHEY_SIMPLEX
            # bottomLeftCornerOfText and fontColor are assigned but the calls
            # below pass literal values instead.
            bottomLeftCornerOfText = (5, 90)
            fontScale = 1
            fontColor = (0, 0, 255)
            lineType = 2
            # The ground truth is looked up by the last '/'-separated component
            # of the image path; a missing entry raises KeyError.
            if pred == dict_gt[img_name.split('/')[-1]]:
                # Correct: draw the prediction with colour (0, 255, 0) and crop
                # the canvas to 96 rows since no second text line is needed.
                cv2.putText(raw_img, pred, (5, 90), font, fontScale,
                            (0, 255, 0), lineType)
                raw_img = raw_img[:96, :200]
                cv2.imwrite('./trash/true/' + img_name.split('/')[-1], raw_img)
            else:
                # Wrong: prediction with colour (0, 0, 255) at y=90, ground
                # truth with colour (0, 255, 0) below it at y=125.
                cv2.putText(raw_img, pred, (5, 90), font, fontScale,
                            (0, 0, 255), lineType)
                cv2.putText(raw_img, dict_gt[img_name.split('/')[-1]],
                            (5, 125), font, fontScale, (0, 255, 0), lineType)
                cv2.imwrite('./trash/false/' + img_name.split('/')[-1],
                            raw_img)
            print(f'{img_name}\t{pred}')
        # NOTE(review): from here on the code uses points_sorted, i,
        # order_sorted, result_folder, filename and rgb_img, none of which are
        # defined in this function -- this statement raises NameError on the
        # first batch.  Presumably a fragment of a crop-and-classify routine
        # that was pasted in by mistake; verify before use.
        min_point = points_sorted[i][0]
        max_point = points_sorted[i][1]
        #print("Cropped image")
        mask_file = result_folder + filename + "_" + str(
            order_sorted[i]) + "_" + str(i) + '.jpg'
        #print(mask_file)
        # Crop rows min_point[1]:max_point[1] and columns
        # min_point[0]:max_point[0] and write the crop to mask_file.
        crop_image = rgb_img[int(min_point[1]):int(max_point[1]),
                             int(min_point[0]):int(max_point[0])]
        #plt.imshow(crop_image)
        #plt.show()
        cv2.imwrite(mask_file, crop_image)

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    #result_folder = './intermediate_result/'
    # Second loader, this time over result_folder (undefined here, see above).
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=result_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)
    print("Starting text classification")
    model.eval()
    with torch.no_grad():
        # NOTE(review): the loop body stops after moving the batch to
        # `device` (not defined in this function; presumably a module-level
        # name) -- no prediction is made.  The remainder appears truncated.
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            #image = (torch.from_numpy(crop_image).unsqueeze(0)).to(device)
Ejemplo n.º 14
0
def train(opt):
    """Train a VGG16-based image classifier and plot its validation metrics.

    A pretrained torchvision VGG16 (with its own classifier removed) feeds a
    new three-layer classifier head whose output size is the number of labels
    listed in ``all_labels.txt``.  Batches come from
    ``Batch_Balanced_Dataset``; every ``opt.valInterval`` iterations the model
    is evaluated on ``opt.valid_data``.  Checkpoints are written to
    ``./saved_models/<opt.experiment_name>/`` every 500 iterations and when
    training stops.  After training, accuracy/loss curves are saved to
    ``accuracy_loss.jpg`` and shown, and the process exits via ``sys.exit()``.

    Changes relative to the previous revision:
      * the model is switched back to ``train()`` mode after each validation
        (previously dropout stayed disabled after the first evaluation);
      * validation runs under ``torch.no_grad()``;
      * labels are converted with ``int64`` instead of ``int8`` so class
        indices above 127 do not wrap around;
      * reported losses are true per-sample means (the batch-mean loss was
        previously divided by the sample count a second time), and the values
        stored for plotting are normalised by the whole validation set rather
        than by the size of its last batch;
      * the plot's x axis follows the number of evaluations actually run, and
        the figure is saved before ``plt.show()`` so the file is not blank.

    NOTE(review): ``opt.saved_model`` is only used to recover the iteration
    counter; the weights themselves are not loaded here -- confirm whether
    that is intended.

    Args:
        opt: options namespace.  Attributes read: ``data_filtering_off``,
            ``select_data``, ``batch_ratio``, ``experiment_name``, ``imgH``,
            ``imgW``, ``PAD``, ``valid_data``, ``batch_size``, ``workers``,
            ``saved_model``, ``num_iter`` and ``valInterval``.
            ``select_data`` and ``batch_ratio`` are replaced by their
            ``'-'``-split lists.
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    # The context manager guarantees the log is closed even if a write fails.
    with open('./saved_models/{}/log_dataset.txt'.format(opt.experiment_name), 'a') as log:
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- hyper-parameters and label set ----
    learning_rate = 1e-4
    label2num, num2label = label_num('all_labels.txt')
    num_classes = len(label2num)
    print('训练类别数:{}'.format(num_classes))
    print('训练集标签列表:\n{}'.format(num2label.values()))
    print('-' * 80)

    class VGGNet(nn.Module):
        """Pretrained VGG16 with its classifier replaced by a smaller head."""

        def __init__(self, num_classes=num_classes):
            super(VGGNet, self).__init__()
            net = models.vgg16(pretrained=True)
            # Empty out the stock classifier so the backbone returns features.
            net.classifier = nn.Sequential()
            self.features = net
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 512),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(512, 128),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(128, num_classes),
            )

        def forward(self, x):
            x = self.features(x)
            x = x.view(x.size(0), -1)
            return self.classifier(x)

    # ---- model, optimiser, loss ----
    model = VGGNet()
    if torch.cuda.is_available():
        model.cuda()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_func = nn.CrossEntropyLoss()  # default reduction: mean over the batch

    Loss_list = []
    Accuracy_list = []

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        # Checkpoints are named 'iter_<n>.pth'; any other name keeps 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print('continue to train, start_iter: {}'.format(start_iter))
        except ValueError:
            pass

    i = start_iter
    while True:
        # ---- training step ----
        # Evaluation below puts the model in eval mode; re-enable dropout.
        model.train()
        image_tensors, labels = train_dataset.get_batch()
        batch_x = image_tensors.to(device)
        # int64 is what CrossEntropyLoss expects and cannot overflow for
        # class indices above 127 the way int8 did.
        batch_y = torch.from_numpy(np.asarray(labels, dtype=np.int64)).to(device)

        out = model(batch_x)
        loss = loss_func(out, batch_y)
        pred = torch.max(out, 1)[1]
        train_correct = (pred == batch_y).sum().item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % 50 == 0:
            print('Step{}:'.format(i + 1))
            # loss is already the batch mean, so it is reported as-is.
            print('Train Loss: {:.6f}, Acc: {:.6f}'.format(
                loss.item(), train_correct / len(labels)))
        # save model per 500 iter.
        if (i + 1) % 500 == 0:
            torch.save(
                model.state_dict(), './saved_models/{}/iter_{}.pth'.format(opt.experiment_name, i+1))

        if i == opt.num_iter:
            torch.save(
                model.state_dict(), './saved_models/{}/iter_{}.pth'.format(opt.experiment_name, i+1))
            print('end the training')
            break
        i += 1

        # ---- evaluation ----
        if i % opt.valInterval == 0:
            model.eval()
            eval_loss = 0.
            eval_correct = 0
            length_of_data = 0
            with torch.no_grad():
                for image_tensors, labels in valid_loader:
                    batch_x = image_tensors.to(device)
                    batch_y = torch.from_numpy(np.asarray(labels, dtype=np.int64)).to(device)
                    batch_len = len(labels)
                    length_of_data += batch_len
                    out = model(batch_x)
                    # Weight the batch-mean loss by the batch size so the final
                    # division yields a per-sample mean over the whole set.
                    eval_loss += loss_func(out, batch_y).item() * batch_len
                    pred = torch.max(out, 1)[1]
                    eval_correct += (pred == batch_y).sum().item()
            if length_of_data == 0:
                # Nothing to average over; avoid dividing by zero.
                continue
            mean_loss = eval_loss / length_of_data
            accuracy = eval_correct / length_of_data
            print('Test Loss: {:.6f}, Acc: {:.6f}'.format(mean_loss, accuracy))

            Loss_list.append(mean_loss)
            Accuracy_list.append(100 * accuracy)

    # ---- plot validation curves (one point per evaluation run) ----
    eval_index = np.arange(len(Accuracy_list))
    plt.figure()
    plt.subplot(2, 1, 1)
    plt.plot(eval_index, Accuracy_list, 'o-')
    plt.title('Test accuracy vs. epoches')
    plt.ylabel('Test accuracy')
    plt.subplot(2, 1, 2)
    plt.plot(eval_index, Loss_list, '.-')
    plt.xlabel('Test loss vs. epoches')
    plt.ylabel('Test loss')
    # Save first: after show() returns the figure may already be destroyed.
    plt.savefig("accuracy_loss.jpg")
    plt.show()
    sys.exit()
Ejemplo n.º 15
0
def demoToTxt2(image_folder, saved_model, txtFile):  # sensitive
    """Recognise every image in a folder and write the results to a text file.

    The function arguments only provide *defaults* for the corresponding
    command-line options: ``parser.parse_args()`` still reads ``sys.argv``, so
    ``--image_folder`` / ``--saved_model`` given on the command line win.

    Each output line is ``<image path>\\t<prediction>``; the same line is also
    printed.

    Changes relative to the previous revision:
      * the output file is managed by ``with`` so it is always flushed and
        closed;
      * inference runs entirely under ``torch.no_grad()`` (only the tensor
        setup used to be covered);
      * tensors and the checkpoint are placed on the GPU only when one is
        available, matching where the model is placed (the unconditional
        ``.cuda()`` calls failed on CPU-only machines);
      * a prediction without an ``[s]`` token is kept whole instead of losing
        its last character to ``pred[:-1]``.

    Args:
        image_folder: default for ``--image_folder``.
        saved_model: default for ``--saved_model``; path of the checkpoint
            loaded into the ``DataParallel``-wrapped model.
        txtFile: path of the text file to (over)write.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_folder',
                        default=image_folder,
                        help='path to image_folder which contains text images')
    parser.add_argument('--workers',
                        type=int,
                        help='number of data loading workers',
                        default=4)
    parser.add_argument('--batch_size',
                        type=int,
                        default=100,
                        help='input batch size')
    parser.add_argument('--saved_model',
                        default=saved_model,
                        help="path to saved_model to evaluation")
    # Data processing
    parser.add_argument('--batch_max_length',
                        type=int,
                        default=20,
                        help='maximum-label-length')
    parser.add_argument('--imgH',
                        type=int,
                        default=32,
                        help='the height of the input image')
    parser.add_argument('--imgW',
                        type=int,
                        default=100,
                        help='the width of the input image')
    parser.add_argument('--rgb', action='store_true', help='use rgb input')
    parser.add_argument('--character',
                        type=str,
                        default='0123456789',
                        help='character label')
    parser.add_argument('--sensitive',
                        default=True,
                        help='for sensitive character mode')
    parser.add_argument('--PAD',
                        default=False,
                        action='store_true',
                        help='whether to keep ratio then pad for image resize')
    # Model Architecture
    parser.add_argument('--Transformation',
                        default='TPS',
                        type=str,
                        help='Transformation stage. None|TPS')
    parser.add_argument('--FeatureExtraction',
                        default='ResNet',
                        type=str,
                        help='FeatureExtraction stage. VGG|RCNN|ResNet')
    parser.add_argument('--SequenceModeling',
                        default='BiLSTM',
                        type=str,
                        help='SequenceModeling stage. None|BiLSTM')
    parser.add_argument('--Prediction',
                        default='Attn',
                        type=str,
                        help='Prediction stage. CTC|Attn')
    parser.add_argument('--num_fiducial',
                        type=int,
                        default=20,
                        help='number of fiducial points of TPS-STN')
    parser.add_argument(
        '--input_channel',
        type=int,
        default=1,
        help='the number of input channel of Feature extractor')
    parser.add_argument(
        '--output_channel',
        type=int,
        default=512,
        help='the number of output channel of Feature extractor')
    parser.add_argument('--hidden_size',
                        type=int,
                        default=256,
                        help='the size of the LSTM hidden state')

    opt = parser.parse_args()
    # vocab / character number configuration
    if opt.sensitive:
        opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'

    cudnn.benchmark = True
    cudnn.deterministic = True
    opt.num_gpu = torch.cuda.device_count()
    # Single source of truth for where the model, weights and inputs live.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # The checkpoint is loaded into the DataParallel wrapper, so its keys are
    # matched against the wrapped ('module.'-prefixed) parameter names.
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    # predict
    model.eval()
    with open(txtFile, 'w') as saved_file, torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(
                batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred).log_softmax(2)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            print('-' * 80)
            print('image_path\tpredicted_labels')
            print('-' * 80)

            for img_name, pred in zip(image_path_list, preds_str):
                if 'Attn' in opt.Prediction:
                    # prune after "end of sentence" token ([s]); partition
                    # leaves pred untouched when the token is absent.
                    pred = pred.partition('[s]')[0]
                print(f'{img_name}\t{pred}')
                saved_file.write(f'{img_name}\t{pred}\n')
Ejemplo n.º 16
0
def runDeepTextNet(segmentedImagesList):
    """Recognise the text in a set of word images and join it into a sentence.

    A TPS-ResNet-BiLSTM-Attn ``Model`` is built from a fixed configuration,
    its weights are loaded from ``TPS-ResNet-BiLSTM-Attn.pth`` (relative to the
    working directory) and every image yielded by ``RawDataset`` is decoded
    greedily.

    Changes relative to the previous revision:
      * the per-word confidence score was computed but never used, and its
        ``cumprod(...)[-1]`` raised ``IndexError`` whenever the model emitted
        ``[s]`` as the very first token (empty prediction); the dead
        computation is removed;
      * inference runs under ``torch.no_grad()``;
      * a prediction without an ``[s]`` token is kept whole instead of losing
        its last character.

    NOTE(review): the model is placed on ``'cpu'`` while the inputs are moved
    to the module-level ``device`` -- presumably these agree in the
    deployment this was written for; verify.

    Args:
        segmentedImagesList: passed unchanged as ``root`` to ``RawDataset``.

    Returns:
        A one-element list holding all predictions joined by single spaces,
        in loader order.
    """
    opt = argparse.Namespace(FeatureExtraction='ResNet',
                             PAD=False,
                             Prediction='Attn',
                             SequenceModeling='BiLSTM',
                             Transformation='TPS',
                             batch_max_length=25,
                             batch_size=192,
                             character='0123456789abcdefghijklmnopqrstuvwxyz',
                             hidden_size=256,
                             image_folder='demo_image/',
                             imgH=32,
                             imgW=100,
                             input_channel=1,
                             num_class=38,
                             num_fiducial=20,
                             num_gpu=0,
                             output_channel=512,
                             rgb=False,
                             saved_model='TPS-ResNet-BiLSTM-Attn.pth',
                             sensitive=False,
                             workers=4)

    # The model is built with the hard-coded num_class above; the converter
    # below only re-derives it afterwards, as in the previous revision.
    model = Model(opt)
    model = torch.nn.DataParallel(model).to('cpu')
    model.load_state_dict(torch.load(opt.saved_model, map_location='cpu'))

    converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    if opt.rgb:
        opt.input_channel = 3

    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=segmentedImagesList, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    # predict
    model.eval()

    out_preds_texts = []
    with torch.no_grad():
        for image_tensors, _ in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(
                batch_size, opt.batch_max_length + 1).fill_(0).to(device)
            preds = model(image, text_for_pred, is_train=False)
            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, length_for_pred)
            for pred in preds_str:
                if 'Attn' in opt.Prediction:
                    # prune after "end of sentence" token ([s]); partition
                    # leaves pred untouched when the token is absent.
                    pred = pred.partition('[s]')[0]
                out_preds_texts.append(pred)

    sentence_out = [' '.join(out_preds_texts)]
    return sentence_out
Ejemplo n.º 17
0
def demo(opt):
    """Run recognition on ``opt.image_folder`` and store results in a CSV.

    The CSV at ``opt.out_df_path`` is read up front; for each image the row
    whose ``video_id`` equals the image's parent directory name and whose
    ``word_id`` equals the integer file stem gets its ``ocr_text`` column set
    to the prediction.  The file is rewritten in place once all batches are
    done.

    Changes relative to the previous revision:
      * a prediction without an ``[s]`` token is kept whole instead of losing
        its last character to ``pred[:-1]``;
      * the softmax / per-step max-probability tensors were computed for every
        batch but never consumed (the confidence code is commented out); the
        dead computation is removed.

    Args:
        opt: options namespace.  Attributes read: ``Prediction``,
            ``character``, ``rgb``, ``imgH``, ``imgW``, ``num_fiducial``,
            ``input_channel``, ``output_channel``, ``hidden_size``,
            ``batch_max_length``, ``Transformation``, ``FeatureExtraction``,
            ``SequenceModeling``, ``saved_model``, ``PAD``, ``image_folder``,
            ``batch_size``, ``workers`` and ``out_df_path``.  ``num_class``
            (and ``input_channel`` when ``opt.rgb`` is set) are written.

    Raises:
        ValueError: if an image's file stem is not an integer.
    """
    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    # Existing table whose 'ocr_text' column is filled in below.
    result_df = pd.read_csv(opt.out_df_path)

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(
                batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, preds_size)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            for img_name, pred in zip(image_path_list, preds_str):
                if 'Attn' in opt.Prediction:
                    # prune after "end of sentence" token ([s]); partition
                    # leaves pred untouched when the token is absent.
                    pred = pred.partition('[s]')[0]

                # Locate the row by <parent dir name> / <integer file stem>.
                img_name_path = Path(img_name)
                row_mask = ((result_df.video_id == img_name_path.parent.stem)
                            & (result_df.word_id == int(img_name_path.stem)))
                result_df.loc[row_mask, 'ocr_text'] = pred

    result_df.to_csv(opt.out_df_path, index=False)
Ejemplo n.º 18
0
def index():
    """Recognise the images in ``opt['image_folder']`` and render the results.

    The model, converter, the prediction helper tensors and the option dict
    all come from ``loader()``.  Predictions and timing figures are printed,
    and a mapping ``{image path without its first 9 characters: decoded
    string}`` is passed to ``index.html`` as ``images``.

    Changes relative to the previous revision:
      * ``items`` now accumulates every batch; it used to be rebuilt from the
        variables left over by the *last* batch only, silently dropping all
        earlier images;
      * an empty loader renders an empty mapping instead of raising
        ``NameError`` on ``image_path_list``;
      * the printed prediction is kept whole when no ``[s]`` token is present
        instead of losing its last character.

    NOTE(review): as before, the value stored in ``items`` is the decoded
    string *before* ``[s]`` pruning -- confirm whether the template expects
    the pruned text instead.  ``text_for_pred`` / ``length_for_pred`` are
    created once by ``loader()``; presumably they are sized for
    ``opt['batch_size']``, so a smaller final batch may not match -- verify.

    Returns:
        The result of ``render_template('index.html', images=items)``.
    """
    model, converter, length_for_pred, text_for_pred, opt = loader()
    start_time = time.time()

    AlignCollate_demo = AlignCollate(imgH=opt['imgH'],
                                     imgW=opt['imgW'],
                                     keep_ratio_with_pad=opt['PAD'])
    demo_data = RawDataset(root=opt['image_folder'], opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt['batch_size'],
                                              shuffle=False,
                                              num_workers=int(opt['workers']),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    get_data = time.time() - start_time

    # predict
    items = {}
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            if 'CTC' in opt['Prediction']:
                preds = model(image, text_for_pred)
                preds = preds.log_softmax(2)
                # select max probability, then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability, then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            print('-' * 80)
            print('image_path\tpredicted_labels')
            print('-' * 80)
            for img_name, raw_pred in zip(image_path_list, preds_str):
                pred = raw_pred
                if 'Attn' in opt['Prediction']:
                    # prune after "end of sentence" token ([s]); partition
                    # leaves the text untouched when the token is absent.
                    pred = raw_pred.partition('[s]')[0]

                print(f'{img_name}\t{pred}')
                # Key: the path with its first 9 characters removed.
                items[img_name[9:]] = raw_pred

    forward_time = time.time() - start_time
    only_infer_time = forward_time - get_data

    print('*' * 80)
    print('get_dta_time:{:.5f}[sec]'.format(get_data))
    print('only_infer_time:{:.5f}[sec]'.format(only_infer_time))
    print('total_time:{:.5f}[sec]'.format(forward_time))
    print('*' * 80)

    return render_template('index.html', images=items)
Ejemplo n.º 19
0
def demo(opt):
    """Recognise every image under ``opt.image_folder`` and log the predictions.

    Builds the recognition model described by ``opt``, loads the weights from
    ``opt.saved_model`` and greedily decodes each batch produced by
    ``RawDataset``.  For every image the crop is shown with ``cv2_imshow``,
    the predicted text and its confidence score are printed, and one line is
    appended to ``<prefix>_log_demo_result_vgg.txt``.  ``<prefix>`` is the
    image path (resolved relative to the hard-coded crop root
    ``/content/Result/Crop Words/``) with the last eight ``_``-separated
    fields of the file name removed.

    Args:
        opt: Options namespace.  Reads ``Prediction``, ``character``, ``rgb``,
            ``saved_model``, ``image_folder``, ``imgH``, ``imgW``, ``PAD``,
            ``batch_size``, ``workers`` and ``batch_max_length`` (plus the
            architecture fields echoed in the banner).  ``opt.num_class`` is
            overwritten, and ``opt.input_channel`` is set to 3 when
            ``opt.rgb`` is truthy.
    """
    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    # Root directory the crop paths are resolved against (loop-invariant).
    start = '/content/Result/Crop Words/'

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                             1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index.data, preds_size.data)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list,
                                                     preds_str,
                                                     preds_max_prob):
                path = os.path.relpath(img_name, start)
                folder = os.path.dirname(path)
                image_name = os.path.basename(path)

                # Drop the last eight '_'-separated fields of the file name to
                # obtain the prefix shared by the log file.
                file_name = '_'.join(image_name.split('_')[:-8])
                txt_file = os.path.join(start, folder, file_name)

                if 'Attn' in opt.Prediction:
                    # prune after "end of sentence" token ([s])
                    # NOTE(review): find() returns -1 when '[s]' is absent,
                    # which trims the final character - confirm this is intended.
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]
                    pred_max_prob = pred_max_prob[:pred_EOS]

                # calculate confidence score (= multiply of pred_max_prob);
                # an empty prediction has no probabilities to multiply.
                if pred_max_prob.numel() > 0:
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                else:
                    confidence_score = 0.0

                imgcropped = cv2.imread(img_name)
                cv2_imshow(imgcropped)
                print(f'{pred:25s}\t {confidence_score:0.4f}\n')

                # One handle per image, always released: the target file can
                # differ from one image to the next.
                with open(f'{txt_file}_log_demo_result_vgg.txt', 'a') as log:
                    log.write(
                        f'{image_name:25s}\t {pred:25s}\t {confidence_score:0.4f}\n'
                    )
def demo(opt):
    """Recognise the images under ``opt.image_folder`` and report text + HTML results.

    Builds the model (the guided variant from ``model_guide`` when
    ``opt.guide_training`` is truthy, otherwise ``model``), loads the weights
    from ``opt.saved_model`` non-strictly and decodes every batch produced by
    ``RawDataset``.  For each batch a header and one line per image are
    printed and appended to ``./log_demo_result.txt``.  Every image is also
    embedded as a base64 PNG in a DataFrame row together with its path,
    predicted label and confidence; when ``opt.is_save`` is truthy that table
    is written to ``result.html``.

    Args:
        opt: Options namespace.  Reads ``guide_training``, ``baiduCTC``,
            ``beam_search``, ``Prediction``, ``character``, ``rgb``,
            ``saved_model``, ``image_folder``, ``imgH``, ``imgW``, ``PAD``,
            ``batch_size``, ``workers``, ``batch_max_length`` and ``is_save``
            (plus the architecture fields echoed in the banner).  Writes
            ``num_class``, ``num_class_ctc`` and ``num_class_attn``, and sets
            ``input_channel`` to 3 when ``opt.rgb`` is truthy.
    """
    # model configuration
    if opt.guide_training:
        from model_guide import Model
    else:
        from model import Model
    if opt.baiduCTC:
        converter = CTCLabelConverterForBaiduWarpctc(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)
    if opt.Prediction == 'Attn':
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    opt.num_class_ctc = opt.num_class
    opt.num_class_attn = opt.num_class_ctc + 1

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model (strict=False: keys missing from / extra in the checkpoint are tolerated)
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device), strict=False)

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # predict
    model.eval()
    data = pd.DataFrame()
    dashed_line = '-' * 80
    head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'
    with torch.no_grad():
        ind = 0  # running row index of the HTML table across all batches
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                if opt.guide_training:
                    preds = model.module.inference(image, text_for_pred)
                else:
                    preds = model(image, text_for_pred)

                preds_size = torch.IntTensor([preds.size(1)] * batch_size)

                # Select max probability (greedy decoding) then decode index to character
                if opt.baiduCTC:
                    if opt.beam_search:
                        # beam search consumes the raw scores, not the argmax
                        preds_index = preds
                    else:
                        _, preds_index = preds.max(2)
                        preds_index = preds_index.view(-1)
                else:
                    _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index.data, preds_size.data, opt.beam_search)
            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)

            # The context manager releases the log even if an image fails below.
            with open('./log_demo_result.txt', 'a') as log:
                print(f'{dashed_line}\n{head}\n{dashed_line}')
                log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

                for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob):
                    if 'Attn' in opt.Prediction:
                        # prune after "end of sentence" token ([s])
                        # NOTE(review): find() returns -1 when '[s]' is absent,
                        # which trims the final character - confirm this is intended.
                        pred_EOS = pred.find('[s]')
                        pred = pred[:pred_EOS]
                        pred_max_prob = pred_max_prob[:pred_EOS]

                    # calculate confidence score (= multiply of pred_max_prob);
                    # an empty prediction has no probabilities to multiply.
                    if pred_max_prob.numel() > 0:
                        confidence_score = pred_max_prob.cumprod(dim=0)[-1].item()
                    else:
                        confidence_score = 0.0

                    # Re-encode the image as PNG and inline it as a data URI.
                    img = cv2.imread(img_name)
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    img_buffer = io.BytesIO()
                    Image.fromarray(img).save(img_buffer, format="PNG")
                    imgStr = base64.b64encode(img_buffer.getvalue()).decode("utf-8")

                    data.loc[ind, 'img'] = '<img src="data:image/png;base64,{0:s}">'.format(imgStr)
                    data.loc[ind, 'id'] = img_name
                    data.loc[ind, 'label'] = pred
                    data.loc[ind, 'conf'] = round(confidence_score, 3)
                    ind += 1
                    print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}')
                    log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n')

    if opt.is_save:
        # escape=False keeps the inline <img> tags as markup.
        with open("result.html", "w") as text_file:
            text_file.write(data.to_html(escape=False))
Ejemplo n.º 21
0
def demo(opt):
    """Recognise the images under ``opt.image_folder`` and dump the results as JSON.

    Builds the model described by ``opt`` (CTC, Bert or attention decoding,
    chosen from ``opt.Prediction``), loads the weights from
    ``opt.saved_model`` and greedily decodes every batch produced by
    ``RawDataset``.  The predictions are written to
    ``./result/<name>/result.json`` where ``<name>`` is the second-to-last
    ``/``-separated component of ``opt.image_folder``.  Each JSON key is the
    image file name with ``'gt'`` replaced by ``'res'`` and everything from
    the first ``.`` removed; each value is ``[{"transcription": <text>}]``.

    Runs on CUDA when it is available and on CPU otherwise.

    Args:
        opt: Options namespace.  Reads ``Prediction``, ``character``, ``rgb``,
            ``saved_model``, ``image_folder``, ``imgH``, ``imgW``, ``PAD``,
            ``batch_size``, ``workers`` and ``batch_max_length`` (plus the
            architecture fields echoed in the banner).  Writes ``num_class``
            and ``alphabet_size``, and sets ``input_channel`` to 3 when
            ``opt.rgb`` is truthy.
    """
    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    elif 'Bert' in opt.Prediction:
        converter = TransformerConverter(opt.character, opt.batch_max_length)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    opt.alphabet_size = len(opt.character) + 2  # +2 for [UNK]+[EOS]

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # One device for the weights and for every input tensor, so the function
    # also works on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(
        demo_data, batch_size=opt.batch_size,
        shuffle=False,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_demo, pin_memory=True)

    # mkdir result
    experiment_name = os.path.join('./result', opt.image_folder.split('/')[-2])
    os.makedirs(experiment_name, exist_ok=True)
    result = {}

    # predict
    model.eval()
    # no_grad must cover the forward passes as well, otherwise autograd
    # graphs are built (and kept alive) for every batch.
    with torch.no_grad():
        for idx, (image_tensors, image_path_list) in enumerate(demo_loader):
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred).log_softmax(2)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)

            elif 'Bert' in opt.Prediction:
                pad_mask = None
                preds = model(image, pad_mask)

                # select max probability (greedy decoding) then decode index to character;
                # the second element of the model output is what gets decoded.
                _, preds_index = preds[1].max(2)
                length_for_pred = torch.IntTensor([preds_index.size(-1)] * batch_size).to(device)
                preds_str = converter.decode(preds_index, length_for_pred)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            # progress: batch index / (possibly fractional) number of batches
            print(f'{idx}/{len(demo_data) / opt.batch_size}')

            for img_name, pred in zip(image_path_list, preds_str):
                if 'Attn' in opt.Prediction:
                    pred = pred[:pred.find('[s]')]  # prune after "end of sentence" token ([s])

                # write in json
                name = f'{img_name}'.split('/')[-1].replace('gt', 'res').split('.')[0]
                result[name] = [{"transcription": f'{pred}'}]

    with open(f'{experiment_name}/result.json', 'w') as f:
        json.dump(result, f)
        print("writed finish...")
Ejemplo n.º 22
0
def demo(opt):
    """Evaluate the model on an LMDB dataset and sort the samples by outcome.

    Builds the model described by ``opt``, loads the weights from
    ``opt.saved_model`` and greedily decodes every batch of the
    ``LmdbDataset`` rooted at ``opt.image_folder``.  Prediction and ground
    truth are compared after keeping only alphanumeric characters and
    upper-casing them.  The original image of each sample is saved as
    ``result/fail/<key>_<pred>_<gt>.jpg`` or
    ``result/success/<key>_<pred>_<gt>.jpg``, and the overall accuracy is
    printed at the end.

    The loader uses ``drop_last=True``, so samples of a trailing incomplete
    batch are not evaluated.  The process exits with status 1 as soon as an
    image cannot be saved.

    Args:
        opt: Options namespace.  Reads ``Prediction``, ``character``, ``rgb``,
            ``saved_model``, ``image_folder``, ``imgH``, ``imgW``, ``PAD``,
            ``batch_size``, ``workers`` and ``batch_max_length`` (plus the
            architecture fields echoed in the banner).  Writes ``num_class``
            and sets ``input_channel`` to 3 when ``opt.rgb`` is truthy.
    """
    # model configuration
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)
    print(opt.num_class)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = LmdbDataset(root=opt.image_folder, opt=opt, mode='Val')

    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True,
                                              drop_last=True)

    # Make sure both output folders exist; otherwise the very first save
    # would fail and terminate the run.
    for outcome_dir in ('fail', 'success'):
        os.makedirs(os.path.join('result', outcome_dir), exist_ok=True)

    # predict
    model.eval()
    fail_count, sample_count = 0, 0
    # The log file is opened (created if missing) for the duration of the run;
    # nothing is written to it at present.  The context manager guarantees it
    # is closed on the early-exit path too.
    with open('./log_demo_result.txt', 'a'), torch.no_grad():
        # The second item of each batch is used as the ground-truth label.
        for image_tensors, image_path_list, original_images, indexes in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                             1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, preds_size)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)

            for gt, pred, pred_max_prob, original_image, lmdb_key in zip(
                    image_path_list, preds_str, preds_max_prob,
                    original_images, indexes):
                if 'Attn' in opt.Prediction:
                    # prune after "end of sentence" token ([s])
                    pred_EOS = pred.find('[s]')
                    pred = pred[:pred_EOS]
                    pred_max_prob = pred_max_prob[:pred_EOS]

                # Confidence score (= product of pred_max_prob).  It is only
                # needed by the per-sample report, which is currently disabled.
                if pred_max_prob.shape[0] > 0:
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                else:
                    confidence_score = 0.0

                # Compare case-insensitively on alphanumeric characters only.
                compare_gt = "".join(x.upper() for x in gt if x.isalnum())
                compare_pred = "".join(x.upper() for x in pred if x.isalnum())

                if compare_gt != compare_pred:
                    fail_count += 1
                    outcome_dir = 'fail'
                else:
                    outcome_dir = 'success'

                sample_tag = f'{lmdb_key}_{compare_pred}_{compare_gt}'
                try:
                    original_image.save(
                        os.path.join('result', outcome_dir,
                                     f'{sample_tag}.jpg'))
                except Exception as e:
                    print(f'Error: {e} {sample_tag}')
                    raise SystemExit(1)

                sample_count += 1

    # No batch at all (e.g. fewer samples than one full batch with
    # drop_last=True) must not end in a division by zero.
    accuracy = (sample_count - fail_count) / sample_count if sample_count else 0.0
    print(
        f'total accuracy: {accuracy:.2f} number of sample: {sample_count}'
    )
def train(opt):
    """ dataset preparation """
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')

    #considering the real images for discriminator
    opt.batch_size = opt.batch_size * 2

    train_dataset = Batch_Balanced_Dataset(opt)

    log = open(os.path.join(opt.exp_dir, opt.exp_name, 'log_dataset.txt'), 'a')
    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(
        root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=
        True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    log.write(valid_dataset_log)
    print('-' * 80)
    log.write('-' * 80 + '\n')
    log.close()
    """ model configuration """
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3

    model = AdaINGen(opt)
    ocrModel = Model(opt)
    disModel = MsImageDis(opt)

    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    #  weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # Recognizer weight initialization
    for name, param in ocrModel.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # Discriminator weight initialization
    for name, param in disModel.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception as e:  # for batchnorm.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    ocrModel = torch.nn.DataParallel(ocrModel).to(device)
    ocrModel.train()

    model = torch.nn.DataParallel(model).to(device)
    model.train()

    disModel = torch.nn.DataParallel(disModel).to(device)
    disModel.train()

    #loading pre-trained model
    if opt.saved_ocr_model != '':
        print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
        if opt.FT:
            ocrModel.load_state_dict(torch.load(opt.saved_ocr_model),
                                     strict=False)
        else:
            ocrModel.load_state_dict(torch.load(opt.saved_ocr_model))
    print("OCRModel:")
    print(ocrModel)

    if opt.saved_synth_model != '':
        print(f'loading pretrained synth model from {opt.saved_synth_model}')
        if opt.FT:
            model.load_state_dict(torch.load(opt.saved_synth_model),
                                  strict=False)
        else:
            model.load_state_dict(torch.load(opt.saved_synth_model))
    print("SynthModel:")
    print(model)

    if opt.saved_dis_model != '':
        print(
            f'loading pretrained discriminator model from {opt.saved_dis_model}'
        )
        if opt.FT:
            disModel.load_state_dict(torch.load(opt.saved_dis_model),
                                     strict=False)
        else:
            disModel.load_state_dict(torch.load(opt.saved_dis_model))
    print("DisModel:")
    print(disModel)
    """ setup loss """
    if 'CTC' in opt.Prediction:
        ocrCriterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        ocrCriterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0

    recCriterion = torch.nn.L1Loss()
    styleRecCriterion = torch.nn.L1Loss()

    # loss averager
    loss_avg_ocr = Averager()
    loss_avg = Averager()
    loss_avg_dis = Averager()

    ##---------------------------------------##
    # filter that only require gradient decent
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))
    # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("SynthOptimizer:")
    print(optimizer)

    #filter parameters for OCR training
    ocr_filtered_parameters = []
    ocr_params_num = []
    for p in filter(lambda p: p.requires_grad, ocrModel.parameters()):
        ocr_filtered_parameters.append(p)
        ocr_params_num.append(np.prod(p.size()))
    print('OCR Trainable params num : ', sum(ocr_params_num))

    # setup optimizer
    if opt.adam:
        ocr_optimizer = optim.Adam(ocr_filtered_parameters,
                                   lr=opt.lr,
                                   betas=(opt.beta1, 0.999))
    else:
        ocr_optimizer = optim.Adadelta(ocr_filtered_parameters,
                                       lr=opt.lr,
                                       rho=opt.rho,
                                       eps=opt.eps)
    print("OCROptimizer:")
    print(ocr_optimizer)

    #filter parameters for OCR training
    dis_filtered_parameters = []
    dis_params_num = []
    for p in filter(lambda p: p.requires_grad, disModel.parameters()):
        dis_filtered_parameters.append(p)
        dis_params_num.append(np.prod(p.size()))
    print('Dis Trainable params num : ', sum(dis_params_num))

    # setup optimizer
    if opt.adam:
        dis_optimizer = optim.Adam(dis_filtered_parameters,
                                   lr=opt.lr,
                                   betas=(opt.beta1, 0.999))
    else:
        dis_optimizer = optim.Adadelta(dis_filtered_parameters,
                                       lr=opt.lr,
                                       rho=opt.rho,
                                       eps=opt.eps)
    print("DisOptimizer:")
    print(dis_optimizer)
    ##---------------------------------------##
    """ final options """
    with open(os.path.join(opt.exp_dir, opt.exp_name, 'opt.txt'),
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)
    """ start training """
    start_iter = 0
    if opt.saved_synth_model != '':
        try:
            start_iter = int(
                opt.saved_synth_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    best_accuracy_ocr = -1
    best_norm_ED_ocr = -1
    iteration = start_iter
    # cntr=0
    while (True):
        # train part

        image_tensors_all, labels_1_all, labels_2_all = train_dataset.get_batch(
        )

        # ## comment
        # pdb.set_trace()
        # for imgCntr in range(image_tensors.shape[0]):
        #     save_image(tensor2im(image_tensors[imgCntr]),'temp/'+str(imgCntr)+'.png')
        # pdb.set_trace()
        # ###
        # print(cntr)
        # cntr+=1
        disCnt = int(image_tensors_all.size(0) / 2)
        image_tensors, image_tensors_real, labels_1, labels_2 = image_tensors_all[:disCnt], image_tensors_all[
            disCnt:disCnt +
            disCnt], labels_1_all[:disCnt], labels_2_all[:disCnt]

        image = image_tensors.to(device)
        image_real = image_tensors_real.to(device)
        text_1, length_1 = converter.encode(
            labels_1, batch_max_length=opt.batch_max_length)
        text_2, length_2 = converter.encode(
            labels_2, batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        images_recon_1, images_recon_2, style = model(image, text_1, text_2)

        if 'CTC' in opt.Prediction:

            #ocr training
            preds_ocr = ocrModel(image, text_1)
            preds_size_ocr = torch.IntTensor([preds_ocr.size(1)] * batch_size)
            preds_ocr = preds_ocr.log_softmax(2).permute(1, 0, 2)

            ocrCost_train = ocrCriterion(preds_ocr, text_1, preds_size_ocr,
                                         length_1)

            #dis training
            #Check: Using alternate real images
            disCost = opt.disWeight * 0.5 * (
                disModel.module.calc_dis_loss(images_recon_1.detach(),
                                              image_real) +
                disModel.module.calc_dis_loss(images_recon_2.detach(), image))

            #synth training
            preds_1 = ocrModel(images_recon_1, text_1)
            preds_size_1 = torch.IntTensor([preds_1.size(1)] * batch_size)
            preds_1 = preds_1.log_softmax(2).permute(1, 0, 2)

            preds_2 = ocrModel(images_recon_2, text_2)
            preds_size_2 = torch.IntTensor([preds_2.size(1)] * batch_size)
            preds_2 = preds_2.log_softmax(2).permute(1, 0, 2)

            ocrCost = 0.5 * (
                ocrCriterion(preds_1, text_1, preds_size_1, length_1) +
                ocrCriterion(preds_2, text_2, preds_size_2, length_2))

            #gen training
            disGenCost = 0.5 * (disModel.module.calc_gen_loss(images_recon_1) +
                                disModel.module.calc_gen_loss(images_recon_2))

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            ocrCost = ocrCriterion(preds.view(-1, preds.shape[-1]),
                                   target.contiguous().view(-1))

        recCost = recCriterion(images_recon_1, image)
        styleRecCost = styleRecCriterion(
            model(images_recon_2, None, None, styleFlag=True), style.detach())

        cost = opt.ocrWeight * ocrCost + opt.reconWeight * recCost + opt.disWeight * disGenCost + opt.styleReconWeight * styleRecCost

        disModel.zero_grad()
        disCost.backward()
        torch.nn.utils.clip_grad_norm_(
            disModel.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        dis_optimizer.step()

        loss_avg_dis.add(disCost)

        model.zero_grad()
        ocrModel.zero_grad()
        disModel.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        #training OCR
        ocrModel.zero_grad()
        ocrCost_train.backward()
        torch.nn.utils.clip_grad_norm_(
            ocrModel.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        ocr_optimizer.step()

        loss_avg_ocr.add(ocrCost_train)

        #START HERE
        # validation part

        if (
                iteration + 1
        ) % opt.valInterval == 0 or iteration == 0:  # To see training progress, we also conduct validation when 'iteration == 0'

            #Save training images
            os.makedirs(os.path.join(opt.exp_dir, opt.exp_name, 'trainImages',
                                     str(iteration)),
                        exist_ok=True)
            for trImgCntr in range(batch_size):
                try:

                    save_image(
                        tensor2im(image[trImgCntr].detach()),
                        os.path.join(
                            opt.exp_dir, opt.exp_name, 'trainImages',
                            str(iteration),
                            str(trImgCntr) + '_input_' + labels_1[trImgCntr] +
                            '.png'))
                    save_image(
                        tensor2im(images_recon_1[trImgCntr].detach()),
                        os.path.join(
                            opt.exp_dir, opt.exp_name, 'trainImages',
                            str(iteration),
                            str(trImgCntr) + '_recon_' + labels_1[trImgCntr] +
                            '.png'))
                    save_image(
                        tensor2im(images_recon_2[trImgCntr].detach()),
                        os.path.join(
                            opt.exp_dir, opt.exp_name, 'trainImages',
                            str(iteration),
                            str(trImgCntr) + '_pair_' + labels_2[trImgCntr] +
                            '.png'))
                except:
                    print('Warning while saving training image')

            elapsed_time = time.time() - start_time
            # for log

            with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_train.txt'),
                      'a') as log:
                model.eval()
                ocrModel.eval()
                disModel.eval()
                with torch.no_grad():
                    # valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
                    #     model, criterion, valid_loader, converter, opt)

                    valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation_synth_lrw(
                        iteration, model, ocrModel, disModel, recCriterion,
                        styleRecCriterion, ocrCriterion, valid_loader,
                        converter, opt)
                model.train()
                ocrModel.train()
                disModel.train()

                # training loss and validation loss
                loss_log = f'[{iteration+1}/{opt.num_iter}] Train OCR loss: {loss_avg_ocr.val():0.5f}, Train Synth loss: {loss_avg.val():0.5f}, Train Dis loss: {loss_avg_dis.val():0.5f}, Valid OCR loss: {valid_loss[0]:0.5f}, Valid Synth loss: {valid_loss[1]:0.5f}, Valid Dis loss: {valid_loss[2]:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                loss_avg_ocr.reset()
                loss_avg.reset()
                loss_avg_dis.reset()

                current_model_log_ocr = f'{"Current_accuracy_OCR":17s}: {current_accuracy[0]:0.3f}, {"Current_norm_ED_OCR":17s}: {current_norm_ED[0]:0.2f}'
                current_model_log_1 = f'{"Current_accuracy_recon":17s}: {current_accuracy[1]:0.3f}, {"Current_norm_ED_recon":17s}: {current_norm_ED[1]:0.2f}'
                current_model_log_2 = f'{"Current_accuracy_pair":17s}: {current_accuracy[2]:0.3f}, {"Current_norm_ED_pair":17s}: {current_norm_ED[2]:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy[1] > best_accuracy:
                    best_accuracy = current_accuracy[1]
                    torch.save(
                        model.state_dict(),
                        os.path.join(opt.exp_dir, opt.exp_name,
                                     'best_accuracy.pth'))
                    torch.save(
                        disModel.state_dict(),
                        os.path.join(opt.exp_dir, opt.exp_name,
                                     'best_accuracy_dis.pth'))
                if current_norm_ED[1] > best_norm_ED:
                    best_norm_ED = current_norm_ED[1]
                    torch.save(
                        model.state_dict(),
                        os.path.join(opt.exp_dir, opt.exp_name,
                                     'best_norm_ED.pth'))
                    torch.save(
                        disModel.state_dict(),
                        os.path.join(opt.exp_dir, opt.exp_name,
                                     'best_norm_ED_dis.pth'))
                best_model_log = f'{"Best_accuracy_Recon":17s}: {best_accuracy:0.3f}, {"Best_norm_ED_Recon":17s}: {best_norm_ED:0.2f}'

                # keep best accuracy model (on valid dataset)
                if current_accuracy[0] > best_accuracy_ocr:
                    best_accuracy_ocr = current_accuracy[0]
                    torch.save(
                        ocrModel.state_dict(),
                        os.path.join(opt.exp_dir, opt.exp_name,
                                     'best_accuracy_ocr.pth'))
                if current_norm_ED[0] > best_norm_ED_ocr:
                    best_norm_ED_ocr = current_norm_ED[0]
                    torch.save(
                        ocrModel.state_dict(),
                        os.path.join(opt.exp_dir, opt.exp_name,
                                     'best_norm_ED_ocr.pth'))
                best_model_log_ocr = f'{"Best_accuracy_ocr":17s}: {best_accuracy_ocr:0.3f}, {"Best_norm_ED_ocr":17s}: {best_norm_ED_ocr:0.2f}'

                loss_model_log = f'{loss_log}\n{current_model_log_ocr}\n{current_model_log_1}\n{current_model_log_2}\n{best_model_log_ocr}\n{best_model_log}'
                print(loss_model_log)
                log.write(loss_model_log + '\n')

                # show some predicted results
                dashed_line = '-' * 80
                head = f'{"Ground Truth":32s} | {"Prediction":25s} | Confidence Score & T/F'
                predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                for gt_ocr, pred_ocr, confidence_ocr, gt_1, pred_1, confidence_1, gt_2, pred_2, confidence_2 in zip(
                        labels[0][:5], preds[0][:5], confidence_score[0][:5],
                        labels[1][:5], preds[1][:5], confidence_score[1][:5],
                        labels[2][:5], preds[2][:5], confidence_score[2][:5]):
                    if 'Attn' in opt.Prediction:
                        gt = gt[:gt.find('[s]')]
                        pred = pred[:pred.find('[s]')]

                    predicted_result_log += f'{"ocr"}: {gt_ocr:27s} | {pred_ocr:25s} | {confidence_ocr:0.4f}\t{str(pred_ocr == gt_ocr)}\n'
                    predicted_result_log += f'{"recon"}: {gt_1:25s} | {pred_1:25s} | {confidence_1:0.4f}\t{str(pred_1 == gt_1)}\n'
                    predicted_result_log += f'{"pair"}: {gt_2:26s} | {pred_2:25s} | {confidence_2:0.4f}\t{str(pred_2 == gt_2)}\n'
                predicted_result_log += f'{dashed_line}'
                print(predicted_result_log)
                log.write(predicted_result_log + '\n')

        # save model per 1e+5 iter.
        if (iteration + 1) % 1e+5 == 0:
            torch.save(
                model.state_dict(),
                os.path.join(opt.exp_dir, opt.exp_name,
                             'iter_{iteration+1}.pth'))
            torch.save(
                ocrModel.state_dict(),
                os.path.join(opt.exp_dir, opt.exp_name,
                             'iter_{iteration+1}_ocr.pth'))
            torch.save(
                disModel.state_dict(),
                os.path.join(opt.exp_dir, opt.exp_name,
                             'iter_{iteration+1}_dis.pth'))

        if (iteration + 1) == opt.num_iter:
            print('end the training')
            sys.exit()
        iteration += 1
Ejemplo n.º 24
0
def train(opt):
    """Train the recognizer with a joint CTC + attention loss.

    Builds the training/validation data pipelines, the model, the optimizer
    and both criteria, then loops over training batches until
    ``opt.num_iter`` iterations have run.  After the first iteration and
    every ``opt.valInterval`` iterations the model is evaluated through
    ``validation`` (with the attention converter and criterion); the best
    checkpoints by accuracy and by normalized edit distance, periodic
    checkpoints, text logs and ``SummaryWriter`` scalars are written under
    ``./saved_models/{opt.exp_name}/``.

    Args:
        opt: Option namespace.  It is mutated: ``select_data`` and
            ``batch_ratio`` are replaced by their ``'-'``-split lists, and
            ``num_class_ctc``, ``num_class_atten`` and (when ``opt.rgb`` is
            set) ``input_channel`` are assigned.

    Raises:
        SystemExit: when the loss becomes NaN or infinite, and when
            ``opt.num_iter`` iterations have completed (normal termination).
    """
    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    # The context manager closes the dataset log even if building the
    # validation dataset or its loader raises.
    with open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') as log:
        valid_dataset, valid_dataset_log = hierarchical_dataset(
            root=opt.valid_data, opt=opt)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=opt.batch_size,
            shuffle=True,  # 'True' to check training progress with validation function.
            num_workers=int(opt.workers),
            collate_fn=AlignCollate_valid,
            pin_memory=True)
        log.write(valid_dataset_log)
        print('-' * 80)
        log.write('-' * 80 + '\n')

    # ---- model configuration ----
    converter_ctc = CTCLabelConverter(opt.character)  # CTC head
    converter_atten = AttnLabelConverter(opt.character)  # attention head
    opt.num_class_ctc = len(converter_ctc.character)
    opt.num_class_atten = len(converter_atten.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class_ctc, opt.num_class_atten, opt.batch_max_length,
          opt.Transformation, opt.FeatureExtraction, opt.SequenceModeling,
          opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # kaiming_normal_ rejected this parameter; fall back to ones.
            if 'weight' in name:
                param.data.fill_(1)

    # keep only the parameters that require gradient descent
    filtered_parameters = [p for p in model.parameters() if p.requires_grad]
    params_num = [np.prod(p.size()) for p in filtered_parameters]
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # use fp16 to train
    model = model.to(device)
    if opt.fp16:
        with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log:
            log.write('==> Enable fp16 training' + '\n')
        print('==> Enable fp16 training')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # data parallel for multi-GPU
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).to(device)
    model.train()

    if opt.saved_model != '':
        print(f'loading pretrained model from {opt.saved_model}')
        # map_location lets a checkpoint saved on another device be restored
        # here (same convention as demo()).
        state_dict = torch.load(opt.saved_model, map_location=device)
        # Fine-tuning (opt.FT) tolerates missing/unexpected keys.
        model.load_state_dict(state_dict, strict=not opt.FT)
    print("Model:")
    print(model)

    # ---- setup loss ----
    criterion_ctc = torch.nn.CTCLoss(zero_infinity=True).to(device)
    criterion_atten = torch.nn.CrossEntropyLoss(ignore_index=0).to(
        device)  # ignore [GO] token = ignore index 0

    # loss averager
    loss_avg = Averager()

    # ---- final options ----
    writer = SummaryWriter(f'./saved_models/{opt.exp_name}')
    with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        for k, v in vars(opt).items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if opt.saved_model != '':
        # Resume the counter from checkpoint names such as 'iter_3000.pth';
        # any other file name simply starts from 0.
        try:
            start_iter = int(opt.saved_model.split('_')[-1].split('.')[0])
            print(f'continue to train, start_iter: {start_iter}')
        except ValueError:
            pass

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    iteration = start_iter

    try:
        while True:
            # train part
            image_tensors, labels = train_dataset.get_batch()
            image = image_tensors.to(device)
            batch_size = image.size(0)
            text_ctc, length_ctc = converter_ctc.encode(
                labels, batch_max_length=opt.batch_max_length)
            text_atten, _ = converter_atten.encode(
                labels, batch_max_length=opt.batch_max_length)

            # type tuple; (tensor, list); text_atten[:, :-1]: align with Attention.forward
            preds_ctc, preds_atten = model(image, text_atten[:, :-1])

            # CTC loss (weighted 0.1)
            preds_size = torch.IntTensor([preds_ctc.size(1)] * batch_size)
            preds_ctc = preds_ctc.log_softmax(2).permute(1, 0, 2)
            cost_ctc = 0.1 * criterion_ctc(preds_ctc, text_ctc, preds_size,
                                           length_ctc)

            # Attention loss: one cross-entropy term per attention output,
            # all scored against the same flattened target.
            target = text_atten[:, 1:].contiguous().view(
                -1)  # without [GO] Symbol
            cost_atten = sum(
                criterion_atten(pred.view(-1, pred.shape[-1]), target)
                for pred in preds_atten)

            cost = cost_ctc + cost_atten
            cost_value = cost.item()
            writer.add_scalar('loss', cost_value, global_step=iteration + 1)

            # Progress line is rewritten in place; every 100th iteration is
            # kept by ending it with a newline.
            line_end = '\n' if (iteration + 1) % 100 == 0 else ''
            print('\riter: {:4d}\tloss: {:6.3f}\tavg: {:6.3f}'.format(
                iteration + 1, cost_value, loss_avg.val()),
                  end=line_end)
            sys.stdout.flush()
            if cost < 0.001:
                print(f'iter: {iteration + 1}\tloss: {cost}')

            optimizer.zero_grad()
            if torch.isnan(cost):
                print(f'iter: {iteration + 1}\tloss: {cost}\t==> Loss is NAN')
                sys.exit()
            if torch.isinf(cost):
                print(f'iter: {iteration + 1}\tloss: {cost}\t==> Loss is INF')
                sys.exit()

            if opt.fp16:
                with amp.scale_loss(cost, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                cost.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                opt.grad_clip)  # gradient clipping with 5 (Default)
            optimizer.step()

            loss_avg.add(cost)
            writer.add_scalar('loss_avg',
                              loss_avg.val(),
                              global_step=iteration + 1)

            # validation part
            # To see training progress, we also conduct validation when 'iteration == 0'
            if iteration == 0 or (iteration + 1) % opt.valInterval == 0:
                elapsed_time = time.time() - start_time
                # for log
                with open(f'./saved_models/{opt.exp_name}/log_train.txt',
                          'a') as log:
                    model.eval()
                    with torch.no_grad():
                        (valid_loss, current_accuracy, current_norm_ED,
                         valid_preds, confidence_score, valid_labels,
                         infer_time, length_of_data) = validation(
                             model, criterion_atten, valid_loader,
                             converter_atten, opt)
                    model.train()
                    writer.add_scalar('accuracy',
                                      current_accuracy,
                                      global_step=iteration + 1)

                    # training loss and validation loss
                    loss_log = f'[{iteration + 1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
                    loss_avg.reset()

                    current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'

                    # keep best accuracy model (on valid dataset)
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.exp_name}/best_accuracy.pth')
                    if current_norm_ED > best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(
                            model.state_dict(),
                            f'./saved_models/{opt.exp_name}/best_norm_ED.pth')
                    best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'

                    loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                    print(loss_model_log)
                    log.write(loss_model_log + '\n')

                    # show some predicted results
                    dashed_line = '-' * 80
                    head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F'
                    predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'
                    for gt, pred, confidence in zip(valid_labels[:5],
                                                    valid_preds[:5],
                                                    confidence_score[:5]):
                        # Cut at the end-of-sentence token.  partition()
                        # keeps the whole string when '[s]' is absent,
                        # whereas slicing with find()'s -1 would silently
                        # drop the last character.
                        gt = gt.partition('[s]')[0]
                        pred = pred.partition('[s]')[0]

                        predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n'
                    predicted_result_log += f'{dashed_line}'
                    print(predicted_result_log)
                    log.write(predicted_result_log + '\n')

            # save model per 1e+5 iter.
            if (iteration + 1) % 100000 == 0:
                torch.save(
                    model.state_dict(),
                    f'./saved_models/{opt.exp_name}/iter_{iteration + 1}.pth')

            if (iteration + 1) == opt.num_iter:
                print('end the training')
                sys.exit()

            iteration += 1
    finally:
        # Runs on normal termination (SystemExit) and on errors alike, so
        # scalars still queued in the writer are flushed to disk.
        writer.close()
Ejemplo n.º 25
0
def demo(opt):
    """Recognize every image under ``opt.image_folder`` and log the results.

    A CTC or attention label converter is chosen from ``opt.Prediction``,
    the model weights are loaded from ``opt.saved_model``, and each batch is
    decoded greedily (arg-max per step).  For every batch a header followed
    by one ``image_path<TAB>prediction<TAB>confidence`` row per image is
    printed and appended to ``./log_demo_result.txt``.

    Args:
        opt: Option namespace.  ``num_class`` and (when ``opt.rgb`` is set)
            ``input_channel`` are assigned on it.
    """
    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    dashed_line = '-' * 80
    head = "image_path" + "\t" + "predicted_labels" + "\t" + "confidence score"
    header = dashed_line + "\n" + head + "\n" + dashed_line

    # predict
    model.eval()
    with torch.no_grad():
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                             1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probabilty (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_index = preds_index.view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)
            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probabilty (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)

            # The context manager closes the log even if a row fails.
            with open('./log_demo_result.txt', 'a') as log:
                print(header)
                # Trailing newline keeps the first row off the dashed line.
                log.write(header + "\n")

                for img_name, pred, pred_max_prob in zip(
                        image_path_list, preds_str, preds_max_prob):
                    if 'Attn' in opt.Prediction:
                        # prune after "end of sentence" token ([s]); when it
                        # is absent find() returns -1, so leave both as is.
                        pred_EOS = pred.find('[s]')
                        if pred_EOS != -1:
                            pred = pred[:pred_EOS]
                            pred_max_prob = pred_max_prob[:pred_EOS]

                    # calculate confidence score (= multiply of pred_max_prob);
                    # an empty prediction has no steps to multiply, score 0.
                    if pred_max_prob.numel() > 0:
                        confidence_score = pred_max_prob.cumprod(
                            dim=0)[-1].item()
                    else:
                        confidence_score = 0.0

                    row = f'{img_name}\t{pred}\t{confidence_score:0.4f}'
                    print(row)
                    log.write(row + "\n")
Ejemplo n.º 26
0
def test(opt):
    """Evaluate a pretrained OCR model on ``opt.test_data`` and log its scores.

    Loads ``ModelV1`` weights from ``opt.saved_ocr_model``, decodes every
    batch of the LMDB test set greedily and accumulates two metrics:

    * word accuracy - fraction of predictions equal to the ground truth;
    * character accuracy - mean of ``1 - edit_distance / max(len)`` (the
      ICDAR2019 normalized edit distance), counting 0 when either string is
      empty.

    The summary line is printed and appended to
    ``{opt.exp_dir}/{opt.exp_name}/log_test.txt``.

    Args:
        opt: Option namespace.  ``classes`` and ``num_class`` are assigned
            on it.
    """
    # Must stay first: locals() is snapshotted before any other local exists.
    lib.print_model_settings(locals().copy())

    if 'Attn' in opt.Prediction:
        converter = AttnLabelConverter(opt.character)
    else:
        converter = CTCLabelConverter(opt.character)

    opt.classes = converter.character

    # ---- dataset preparation ----
    if not opt.data_filtering_off:
        print(
            'Filtering the images containing characters which are not in opt.character'
        )
        print(
            'Filtering the images whose label is longer than opt.batch_max_length'
        )
        # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)

    valid_dataset = LmdbDataset(root=opt.test_data, opt=opt)
    test_data_sampler = data_sampler(valid_dataset,
                                     shuffle=False,
                                     distributed=False)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=False,  # ordering is delegated to the sampler
        sampler=test_data_sampler,
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True,
        drop_last=False)

    print('-' * 80)

    opt.num_class = len(converter.character)

    ocrModel = ModelV1(opt).to(device)

    ## Loading pre-trained files
    print(f'loading pretrained ocr model from {opt.saved_ocr_model}')
    checkpoint = torch.load(opt.saved_ocr_model,
                            map_location=lambda storage, loc: storage)
    ocrModel.load_state_dict(checkpoint)
    # Modules are created in training mode; switch to inference mode so
    # layers such as BatchNorm/Dropout behave deterministically while scoring.
    ocrModel.eval()

    num_samples = 0
    num_exact = 0.0
    norm_ed_sum = 0.0

    for vCntr, (image_input_tensors, labels_gt) in enumerate(valid_loader):
        print(vCntr)

        image_input_tensors = image_input_tensors.to(device)
        text_gt, _ = converter.encode(
            labels_gt, batch_max_length=opt.batch_max_length)
        curr_batch_size = image_input_tensors.shape[0]

        with torch.no_grad():
            # Run OCR prediction
            if 'CTC' in opt.Prediction:
                preds = ocrModel(image_input_tensors, text_gt, is_train=False)
                preds_size = torch.IntTensor([preds.size(1)] *
                                             curr_batch_size)
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index.data,
                                             preds_size.data)
            else:
                length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                                  curr_batch_size).to(device)
                preds = ocrModel(
                    image_input_tensors, text_gt[:, :-1],
                    is_train=False)  # align with Attention.forward
                _, preds_index = preds.max(2)
                # prune after "end of sentence" token ([s]); partition()
                # keeps the whole string when the token is missing instead
                # of dropping its last character like a find()-based slice.
                preds_str = [
                    pred.partition('[s]')[0]
                    for pred in converter.decode(preds_index, length_for_pred)
                ]

        for gt, pred in zip(labels_gt, preds_str):
            # ocr accuracy
            if gt == pred:
                num_exact += 1

            # ICDAR2019 Normalized Edit Distance (0 if either side is empty)
            if gt and pred:
                norm_ed_sum += 1 - edit_distance(pred, gt) / max(
                    len(gt), len(pred))

            num_samples += 1

    if num_samples:
        word_acc = num_exact / float(num_samples)
        char_acc = norm_ed_sum / float(num_samples)
    else:
        # Nothing was evaluated; report zeros instead of dividing by zero.
        print('Warning: the test loader produced no samples')
        word_acc = 0.0
        char_acc = 0.0

    with open(os.path.join(opt.exp_dir, opt.exp_name, 'log_test.txt'),
              'a') as log:
        loss_log = f'Word Acc: {word_acc:0.5f}, Test Input Char Acc: {char_acc:0.5f}'

        print(loss_log)
        log.write(loss_log + "\n")
Ejemplo n.º 27
0
def train(opt):
    """Train the text-recognition model described by ``opt``.

    The function builds the training batch source and a validation loader,
    constructs and initializes the model, optionally restores a checkpoint,
    then runs the optimization loop.  Every ``opt.valInterval`` iterations it
    logs the running loss, runs ``validation`` and saves the best checkpoints.
    When the iteration counter reaches ``opt.num_iter`` the process is
    terminated with ``sys.exit()``; the function never returns normally.

    Args:
        opt: options namespace.  It is mutated in place: ``select_data`` and
            ``batch_ratio`` are split on ``'-'``, ``num_class`` is set from
            the label converter and ``input_channel`` is forced to 3 when
            ``opt.rgb`` is set.

    Side effects:
        Appends to ``opt.txt`` and ``log_train.txt`` and writes ``*.pth``
        checkpoints under ``./saved_models/{opt.experiment_name}/``.
    """
    # ---- dataset preparation ----
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid,
        pin_memory=True)
    print('-' * 80)

    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # Initializer rejected this tensor; fall back to a constant 1
            # for weights and leave everything else untouched.
            if 'weight' in name:
                param.data.fill_(1)

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()
    resuming = opt.continue_model != ''
    if resuming:
        print(f'loading pretrained model from {opt.continue_model}')
        # map_location keeps this consistent with demo() and lets a
        # checkpoint saved on another device be restored onto `device`.
        model.load_state_dict(
            torch.load(opt.continue_model, map_location=device))
    print("Model:")
    print(model)

    # ---- setup loss ----
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
            device)  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # ---- final options: dump every option to opt.txt ----
    with open(f'./saved_models/{opt.experiment_name}/opt.txt',
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # ---- start training ----
    start_iter = 0
    if resuming:
        # The iteration is recovered from names such as "iter_100000.pth".
        # Checkpoints like "best_accuracy.pth" carry no number, so fall back
        # to 0 instead of crashing on int().
        try:
            start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
        except ValueError:
            print('could not parse start_iter from '
                  f'{opt.continue_model}, starting from 0')
        else:
            print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    while (True):
        # train part
        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)
        text, length = converter.encode(labels,
                                        batch_max_length=opt.batch_max_length)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] *
                                         batch_size).to(device)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format

            # To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss
            # https://github.com/jpuigcerver/PyLaia/issues/16
            torch.backends.cudnn.enabled = False
            cost = criterion(preds, text, preds_size, length)
            torch.backends.cudnn.enabled = True

        else:
            preds = model(image, text[:, :-1])  # align with Attention.forward
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]),
                             target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            train_log = f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
            print(train_log)
            # for log
            with open(f'./saved_models/{opt.experiment_name}/log_train.txt',
                      'a') as log:
                log.write(train_log + '\n')
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                # show a few predictions next to their ground truth
                for pred, gt in zip(preds[:5], labels[:5]):
                    if 'Attn' in opt.Prediction:
                        # Cut at the end-of-sentence token only when it is
                        # present: find() returns -1 otherwise, which would
                        # silently drop the last character.
                        pred_eos = pred.find('[s]')
                        if pred_eos != -1:
                            pred = pred[:pred_eos]
                        gt_eos = gt.find('[s]')
                        if gt_eos != -1:
                            gt = gt[:gt_eos]
                    sample_log = f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}'
                    print(sample_log)
                    log.write(sample_log + '\n')

                valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                print(valid_log)
                log.write(valid_log + '\n')

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_accuracy.pth'
                    )
                # here a smaller norm_ED is treated as better
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(
                        model.state_dict(),
                        f'./saved_models/{opt.experiment_name}/best_norm_ED.pth'
                    )
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                print(best_model_log)
                log.write(best_model_log + '\n')

        # save model per 1e+5 iter.
        if (i + 1) % 100000 == 0:
            torch.save(model.state_dict(),
                       f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')

        # ">=" rather than "==": when resuming from an iteration already
        # past num_iter the equality would never hold and the loop would
        # run forever.
        if i >= opt.num_iter:
            print('end the training')
            sys.exit()
        i += 1
# Example no. 28
# 0
def demo(opt):
    """Run the recognizer on ``opt.image_folder`` and report accuracy.

    Loads the checkpoint ``opt.saved_model``, predicts a label for every
    image yielded by ``RawDataset`` and compares it with a ground truth
    taken from the image path: the segment after the first ``'_L_'``, cut at
    its first ``'.'``.  Both strings are compared only up to their first
    ``'('``.  Mismatches are appended to ``./log_demo_result.txt``; matches
    are printed.  The overall accuracy is printed at the end.

    Args:
        opt: options namespace.  ``num_class`` is set from the label
            converter and ``input_channel`` is forced to 3 when ``opt.rgb``
            is set.

    Raises:
        IndexError: if an image path does not contain ``'_L_'``.
    """
    # ---- model configuration ----
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).to(device)

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model, map_location=device))

    # prepare data
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    # predict
    model.eval()
    fail_count, sample_count = 0, 0
    dashed_line = '-' * 80
    head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score'

    # The log is opened once for the whole run and closed by the context
    # manager, so no handle is leaked per batch and an empty loader no
    # longer fails on an undefined `log`.
    with torch.no_grad(), open('./log_demo_result.txt', 'a') as log:
        for image_tensors, image_path_list in demo_loader:
            batch_size = image_tensors.size(0)
            image = image_tensors.to(device)
            # For max length prediction
            length_for_pred = torch.IntTensor([opt.batch_max_length] *
                                              batch_size).to(device)
            text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length +
                                             1).fill_(0).to(device)

            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, preds_size)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

            # a header is emitted for every batch
            print(f'{dashed_line}\n{head}\n{dashed_line}')
            log.write(f'{dashed_line}\n{head}\n{dashed_line}\n')

            preds_prob = F.softmax(preds, dim=2)
            preds_max_prob, _ = preds_prob.max(dim=2)
            for img_name, pred, pred_max_prob in zip(image_path_list,
                                                     preds_str,
                                                     preds_max_prob):
                if 'Attn' in opt.Prediction:
                    pred_EOS = pred.find('[s]')
                    if pred_EOS == -1:
                        # No end-of-sentence token: keep the whole
                        # prediction instead of slicing with -1, which
                        # would drop its last character.
                        pred_EOS = len(pred)
                    pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
                    pred_max_prob = pred_max_prob[:pred_EOS]

                if pred_max_prob.shape[0] > 0:
                    # calculate confidence score (= multiply of pred_max_prob)
                    confidence_score = pred_max_prob.cumprod(dim=0)[-1]
                else:
                    confidence_score = 0.0

                # ground truth is encoded in the path after the first '_L_'
                gt = img_name.split('_L_')[1]
                gt = gt.split('.')[0]
                pred = pred.split('.')[0]

                result_line = f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}'
                # anything from the first '(' onwards is ignored when comparing
                if gt.split('(')[0] != pred.split('(')[0]:
                    fail_count += 1
                    log.write(result_line + '\n')
                else:
                    print(result_line)
                sample_count += 1

    if sample_count == 0:
        # avoid ZeroDivisionError when the image folder yields nothing
        print('total accuracy: n/a (no samples)')
    else:
        print(f'total accuracy: {(sample_count-fail_count)/sample_count:.2f}')