Example #1
0
 def load_model(self):
     """Create the label converter and the network, then load the weights.

     Everything is driven by ``self.opt``. The converter is stored in
     ``self.converter`` and the ``DataParallel``-wrapped network, switched
     to eval mode, in ``self.model``. ``num_class`` (and ``input_channel``
     when ``rgb`` is set) are written back onto ``self.opt``.
     """
     cfg = self.opt
     on_gpu = torch.cuda.is_available()

     # A 'CTC' prediction head uses the CTC converter; anything else is
     # treated as attention.
     converter_cls = (CTCLabelConverter
                      if 'CTC' in cfg.Prediction else AttnLabelConverter)
     self.converter = converter_cls(cfg.character, cfg.subword)
     cfg.num_class = len(self.converter.character)
     if cfg.rgb:
         cfg.input_channel = 3

     self.model = Model(cfg)
     print('model input parameters', cfg.imgH, cfg.imgW, cfg.num_fiducial,
           cfg.input_channel, cfg.output_channel, cfg.hidden_size,
           cfg.num_class, cfg.batch_max_length, cfg.Transformation,
           cfg.FeatureExtraction, cfg.SequenceModeling, cfg.Prediction)

     # Wrap before loading: the state dict is loaded into the wrapper.
     self.model = torch.nn.DataParallel(self.model)
     if on_gpu:
         self.model = self.model.cuda()

     # load model
     print('loading pretrained model from %s' % cfg.saved_model)
     # Without CUDA the checkpoint tensors are remapped onto the CPU.
     load_kwargs = {} if on_gpu else {'map_location': "cpu"}
     weights = torch.load(cfg.saved_model, **load_kwargs)
     self.model.load_state_dict(weights)
     self.model.eval()
Example #2
0
def test(opt):
    """Evaluate the checkpoint at ``opt.saved_model``.

    Builds the converter and model described by ``opt``, loads the weights,
    copies the checkpoint into ``./result/<experiment_name>/`` and then either
    runs ``benchmark_all_eval`` or evaluates ``opt.eval_data`` and appends the
    resulting accuracy to ``log_evaluation.txt`` in that directory.

    Args:
        opt: Options namespace. ``num_class``, ``experiment_name`` and (when
            ``opt.rgb`` is set) ``input_channel`` are written back onto it.

    Raises:
        OSError: If the checkpoint cannot be copied into the result directory
            (the previous shell ``cp`` failed silently).
    """
    import shutil  # function-scope import: only used to archive the checkpoint

    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)
    model = torch.nn.DataParallel(model).cuda()

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    model.load_state_dict(torch.load(opt.saved_model))
    # Experiment name = checkpoint path minus its first component, joined
    # with underscores.
    opt.experiment_name = '_'.join(opt.saved_model.split('/')[1:])

    # --- keep evaluation model and result logs ---
    result_dir = f'./result/{opt.experiment_name}'
    os.makedirs(result_dir, exist_ok=True)
    # shutil.copy instead of a shell `cp`: no shell interpolation of the
    # path (spaces/metacharacters are safe) and it works on any platform.
    shutil.copy(opt.saved_model, result_dir)

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
    else:
        criterion = torch.nn.CrossEntropyLoss(
            ignore_index=0).cuda()  # ignore [GO] token = ignore index 0

    # --- evaluation ---
    model.eval()
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every forward pass.
    with torch.no_grad():
        if opt.benchmark_all_eval:  # evaluation with 10 benchmark evaluation datasets
            benchmark_all_eval(model, criterion, converter, opt)
        else:
            AlignCollate_evaluation = AlignCollate(imgH=opt.imgH,
                                                   imgW=opt.imgW,
                                                   keep_ratio_with_pad=opt.PAD)
            eval_data = hierarchical_dataset(root=opt.eval_data, opt=opt)
            evaluation_loader = torch.utils.data.DataLoader(
                eval_data,
                batch_size=opt.batch_size,
                shuffle=False,
                num_workers=int(opt.workers),
                collate_fn=AlignCollate_evaluation,
                pin_memory=True)
            _, accuracy_by_best_model, _, _, _, _, _ = validation(
                model, criterion, evaluation_loader, converter, opt)

            print(accuracy_by_best_model)
            with open(f'{result_dir}/log_evaluation.txt', 'a') as log:
                log.write(str(accuracy_by_best_model) + '\n')
def train(opt):
    """Train a CTC or attention recognition model described by ``opt``.

    Builds the training/validation data pipelines, the model, the loss and
    the optimizer, then loops forever over training batches. Every
    ``opt.valInterval`` iterations it validates, logs to
    ``./saved_models/<experiment_name>/log_train.txt`` and saves the best
    checkpoints; every 50000 iterations it saves a snapshot. The process
    exits through ``sys.exit()`` once the iteration counter equals
    ``opt.num_iter``.

    Args:
        opt: Options namespace. ``select_data`` and ``batch_ratio`` are
            replaced by their ``'-'``-split lists, and ``num_class`` (plus
            ``input_channel`` when ``opt.rgb`` is set) is written back.
    """
    # --- dataset preparation ---
    opt.select_data = opt.select_data.split('-')
    opt.batch_ratio = opt.batch_ratio.split('-')
    train_dataset = Batch_Balanced_Dataset(opt)

    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size,
        shuffle=True,  # 'True' to check training progress with validation function.
        num_workers=int(opt.workers),
        collate_fn=AlignCollate_valid, pin_memory=True)
    print('-' * 80)

    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel,
          opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction,
          opt.SequenceModeling, opt.Prediction)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:  # for batchnorm.
            # Fallback when the initializer rejects the parameter: weights
            # are set to 1.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # data parallel for multi-GPU
    model = torch.nn.DataParallel(model).cuda()
    model.train()
    if opt.continue_model != '':
        if opt.without_prediction:
            load_model_without_prediction(opt.continue_model, model)
            print(f'loading pretrained model from {opt.continue_model}, without prediction layer')
        else:
            print(f'loading pretrained model from {opt.continue_model}')
            model.load_state_dict(torch.load(opt.continue_model))
    print("Model:")
    print(model)

    # --- setup loss ---
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()  # ignore [GO] token = ignore index 0
    # loss averager
    loss_avg = Averager()

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    # setup optimizer
    if opt.adam:
        optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # --- final options ---
    save_dir = f'./saved_models/{opt.experiment_name}'
    with open(f'{save_dir}/opt.txt', 'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)}: {str(v)}\n'
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0
    if opt.continue_model != '':
        print(f'continue to train, start_iter: {start_iter}')

    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = 1e+6
    i = start_iter

    while True:
        # train part
        for p in model.parameters():
            p.requires_grad = True

        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.cuda()
        text, length = converter.encode(labels)
        batch_size = image.size(0)

        if 'CTC' in opt.Prediction:
            preds = model(image, text).log_softmax(2)
            preds_size = torch.IntTensor([preds.size(1)] * batch_size)
            preds = preds.permute(1, 0, 2)  # to use CTCLoss format
            cost = criterion(preds, text, preds_size, length)

        else:
            preds = model(image, text)
            target = text[:, 1:]  # without [GO] Symbol
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()

        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            train_log = f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}'
            logging.info(train_log)
            # for log
            with open(f'{save_dir}/log_train.txt', 'a') as log:
                log.write(train_log + '\n')
                loss_avg.reset()

                model.eval()
                # Validation needs no gradients; skip autograd bookkeeping.
                with torch.no_grad():
                    (valid_loss, current_accuracy, current_norm_ED, valid_preds, valid_labels,
                     infer_time, length_of_data) = validation(
                        model, criterion, valid_loader, converter, opt)
                model.train()

                for pred, gt in zip(valid_preds[:5], valid_labels[:5]):
                    if 'Attn' in opt.Prediction:
                        # Cut at the end-of-sentence token. partition keeps
                        # the whole string when '[s]' is missing, whereas
                        # slicing with find() == -1 dropped the last char.
                        pred = pred.partition('[s]')[0]
                        gt = gt.partition('[s]')[0]
                    sample_log = f'{pred:20s}, gt: {gt:20s},   {str(pred == gt)}'
                    print(sample_log)
                    log.write(sample_log + '\n')

                valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}'
                valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}'
                log.write(valid_log + '\n')

                # keep best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(model.state_dict(), f'{save_dir}/mtl_best_accuracy.pth')
                if current_norm_ED < best_norm_ED:
                    best_norm_ED = current_norm_ED
                    torch.save(model.state_dict(), f'{save_dir}/best_norm_ED.pth')
                best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}'
                logging.info(best_model_log)
                log.write(best_model_log + '\n')

        # save a snapshot every 50000 iterations
        if (i + 1) % 50000 == 0:
            torch.save(model.state_dict(), f'{save_dir}/iter_{i+1}.pth')

        if i == opt.num_iter:
            logging.info('end the training')
            sys.exit()
        i += 1
Example #4
0
def train(opt):
    """Jointly train the CTC and attention heads of ``MyModel``.

    The total loss is ``opt.ctc_weight * ctc_loss +
    (1 - opt.ctc_weight) * attention_loss``. Every ``opt.valInterval``
    iterations the model is validated with ``mtl_validation``, the results
    are appended to ``<outPath>/<experiment_name>/log_train.txt`` and the
    best-accuracy checkpoint is saved; every 50000 iterations a snapshot is
    saved. The process exits through ``sys.exit()`` once the iteration
    counter equals ``opt.num_iter``.

    Args:
        opt: Options namespace. ``num_class``, ``ctc_num_class`` and (when
            ``opt.rgb`` is set) ``input_channel`` are written back onto it.
    """
    #logging.info(opt)
    train_dataset = Batch_Dataset(opt)
    AlignCollate_valid = AlignCollate(imgH=opt.imgH,
                                      imgW=opt.imgW,
                                      keep_ratio_with_pad=opt.PAD)
    valid_dataset = LmdbDataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=opt.batch_size,
                                               shuffle=True,
                                               num_workers=int(opt.workers),
                                               collate_fn=AlignCollate_valid,
                                               pin_memory=True)
    print('-' * 80)

    # --- model configuration ---
    ctc_converter = CTCLabelConverter(opt.character, opt.subword)
    attn_converter = AttnLabelConverter(opt.character, opt.subword,
                                        opt.batch_max_length)

    opt.num_class = len(attn_converter.character)
    opt.ctc_num_class = len(ctc_converter.character)
    print("ctc num class {}".format(len(ctc_converter.character)))
    print("attention num class {}".format(len(attn_converter.character)))

    if opt.rgb:
        opt.input_channel = 3

    model = MyModel(opt)

    # weight initialization
    for name, param in model.named_parameters():
        if 'localization_fc2' in name:
            # The field is named, so it must be passed as a keyword;
            # a positional argument here raised KeyError('name').
            print('Skip {name} as it is already initialized'.format(
                name=name))
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)
        except Exception:
            # Fallback when the initializer rejects the parameter: weights
            # are set to 1.
            if 'weight' in name:
                param.data.fill_(1)
            continue

    model = torch.nn.DataParallel(model).to(device)

    model.train()
    if opt.continue_model != '':
        print('loading pretrained model from {}'.format(opt.continue_model))
        model.load_state_dict(torch.load(opt.continue_model))

    # --- setup loss ---
    ctc_criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    attn_criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    loss_avg = Averager()
    # Only parameters that require gradients are handed to the optimizer.
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params num : ', sum(params_num))

    if opt.adam:
        optimizer = optim.Adam(filtered_parameters,
                               lr=opt.lr,
                               betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filtered_parameters,
                                   lr=opt.lr,
                                   rho=opt.rho,
                                   eps=opt.eps)
    print("Optimizer:")
    print(optimizer)

    # --- final options ---
    with open(osj(opt.outPath, '{}/opt.txt'.format(opt.experiment_name)),
              'a') as opt_file:
        opt_log = '------------ Options -------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += '{}: {}\n'.format(str(k), str(v))
        opt_log += '---------------------------------------\n'
        print(opt_log)
        opt_file.write(opt_log)

    # --- start training ---
    start_iter = 0
    if opt.continue_model != '':
        print('continue to train, start_iter: {}'.format(start_iter))

    start_time = time.time()
    best_accuracy = -1
    i = start_iter

    while True:
        # train part
        for p in model.parameters():
            p.requires_grad = True

        image_tensors, labels = train_dataset.get_batch()
        image = image_tensors.to(device)

        ctc_text, ctc_length = ctc_converter.encode(labels)
        attn_text, attn_length = attn_converter.encode(labels)
        batch_size = image.size(0)
        # ctc loss
        ctc_preds, attn_preds = model(image, attn_text)
        ctc_preds = ctc_preds.log_softmax(2)
        preds_size = torch.IntTensor([ctc_preds.size(1)] * batch_size)
        ctc_preds = ctc_preds.permute(1, 0, 2)  # to use CTCLoss format
        ctc_cost = ctc_criterion(ctc_preds, ctc_text, preds_size, ctc_length)
        # attn loss
        target = attn_text[:, 1:]  # drop the first column of the targets
        attn_cost = attn_criterion(attn_preds.view(-1, attn_preds.shape[-1]),
                                   target.contiguous().view(-1))
        cost = opt.ctc_weight * ctc_cost + (1.0 - opt.ctc_weight) * attn_cost

        model.zero_grad()
        cost.backward()
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            opt.grad_clip)  # gradient clipping with 5 (Default)
        optimizer.step()
        loss_avg.add(cost)

        # validation part
        if i % opt.valInterval == 0:
            elapsed_time = time.time() - start_time
            logging.info('[{}/{}] Loss: {:0.5f} elapsed_time: {:0.5f}'.format(
                i, opt.num_iter, loss_avg.val(), elapsed_time))
            # for log
            with open(
                    osj(opt.outPath,
                        '{}/log_train.txt'.format(opt.experiment_name)),
                    'a') as log:
                log.write(
                    '[{}/{}] Loss: {:0.5f} elapsed_time: {:0.5f}\n'.format(
                        i, opt.num_iter, loss_avg.val(), elapsed_time))
                loss_avg.reset()

                model.eval()
                with torch.no_grad():
                    valid_loss, current_accuracy, ctc_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data \
                        = mtl_validation(model, ctc_criterion, attn_criterion, valid_loader, ctc_converter, attn_converter, opt)
                model.train()

                for pred, gt in zip(preds[:5], labels[:5]):
                    # Cut at the end-of-sentence token. partition keeps the
                    # whole string when '[s]' is missing, whereas slicing
                    # with find() == -1 dropped the last character.
                    pred = pred.partition('[s]')[0]
                    gt = gt.partition('[s]')[0]
                    sample_log = '{:20s}, gt: {:20s},   {}'.format(
                        pred, gt, str(pred == gt))
                    print(sample_log)
                    log.write(sample_log + '\n')

                valid_log = '[{}/{}] valid loss: {:0.5f}'.format(
                    i, opt.num_iter, valid_loss)
                valid_log += ' accuracy: {:0.3f}'.format(current_accuracy)

                log.write(valid_log + '\n')

                # save best accuracy model
                if current_accuracy > best_accuracy:
                    best_accuracy = current_accuracy
                    torch.save(
                        model.state_dict(),
                        osj(opt.outPath, '{}/best_accuracy.pth'.format(
                            opt.experiment_name)))

                best_model_log = 'best_accuracy: {:0.3f}'.format(best_accuracy)
                logging.info(best_model_log)
                log.write(best_model_log + '\n')

        # save a snapshot every 50000 iterations
        if (i + 1) % 50000 == 0:
            torch.save(
                model.state_dict(),
                osj(opt.outPath,
                    '{}/iter_{}.pth'.format(opt.experiment_name, i + 1)))

        if i == opt.num_iter:
            logging.info('end the training')
            sys.exit()
        i += 1
Example #5
0
class OcrRec:
    """Single-image text recognizer around a pretrained recognition model.

    The model and label converter are built from an options object and the
    weights are loaded once at construction time; ``text_rec`` then runs
    inference on one image at a time.
    """

    def __init__(self, opt=None):
        """Store the options and load the model.

        Args:
            opt: Options object; when falsy a fresh ``ConfigOpt()`` is used.
        """
        self.max_length = 25
        self.opt = ConfigOpt()
        if opt:
            self.opt = opt
        self.batch_size = 1  # text_rec processes exactly one image per call
        self.model = None
        self.converter = None
        self.load_model()

    @staticmethod
    def _strip_eos(text):
        """Return ``text`` up to the first ``'[s]'`` end-of-sentence token.

        The whole string is returned when the token is absent. (Slicing with
        ``text.find('[s]')`` would drop the last character in that case,
        because ``find`` returns -1.)
        """
        return text.partition('[s]')[0]

    def load_model(self):
        """Create the converter and network from ``self.opt`` and load weights.

        Sets ``self.converter`` and ``self.model`` (``DataParallel``-wrapped,
        in eval mode) and writes ``num_class`` (plus ``input_channel`` when
        ``rgb`` is set) back onto ``self.opt``.
        """
        if 'CTC' in self.opt.Prediction:
            self.converter = CTCLabelConverter(self.opt.character,
                                               self.opt.subword)
        else:
            self.converter = AttnLabelConverter(self.opt.character,
                                                self.opt.subword)
        self.opt.num_class = len(self.converter.character)
        if self.opt.rgb:
            self.opt.input_channel = 3
        self.model = Model(self.opt)
        print('model input parameters', self.opt.imgH, self.opt.imgW,
              self.opt.num_fiducial, self.opt.input_channel,
              self.opt.output_channel, self.opt.hidden_size,
              self.opt.num_class, self.opt.batch_max_length,
              self.opt.Transformation, self.opt.FeatureExtraction,
              self.opt.SequenceModeling, self.opt.Prediction)
        use_cuda = torch.cuda.is_available()
        self.model = torch.nn.DataParallel(self.model)
        if use_cuda:
            self.model = self.model.cuda()
        # load model
        print('loading pretrained model from %s' % self.opt.saved_model)
        if use_cuda:
            self.model.load_state_dict(torch.load(self.opt.saved_model))
        else:
            # Without CUDA the checkpoint tensors are remapped onto the CPU.
            self.model.load_state_dict(
                torch.load(self.opt.saved_model, map_location="cpu"))
        self.model.eval()

    def text_rec(self, img):
        """Recognize the text in one image.

        The image is converted to mode ``'L'``, resized to height
        ``self.opt.imgH`` with its aspect ratio kept, and passed through the
        model with greedy (arg-max) decoding.

        :param img: path to an image file, or an already opened image
            object exposing ``mode``, ``size`` and ``convert``
            (presumably a PIL image - confirm against callers)
        :return: the decoded string, cut at the first ``'[s]'`` token
        """
        if isinstance(img, str) and os.path.isfile(img):
            # Convert inside the context manager so the underlying file
            # handle is closed as soon as the pixels have been read.
            with Image.open(img) as opened:
                img = opened.convert('L')
        if img.mode != 'L':
            img = img.convert('L')

        # Scale the width by the same factor that maps the height to imgH.
        ratio = self.opt.imgH / img.size[1]
        target_w = int(img.size[0] * ratio)
        transformer = InferResizeNormalize((target_w, self.opt.imgH))
        img = transformer(img)
        img = img.view(1, *img.size())  # prepend a batch dimension of 1

        # Fixed-length decoding inputs: maximum length for every sample and
        # an all-zero text tensor.
        length_for_pred = torch.IntTensor([self.opt.batch_max_length] *
                                          self.batch_size)
        text_for_pred = torch.LongTensor(
            self.batch_size, self.opt.batch_max_length + 1).fill_(0)
        if torch.cuda.is_available():
            img = img.cuda()
            length_for_pred = length_for_pred.cuda()
            text_for_pred = text_for_pred.cuda()

        with torch.no_grad():
            preds = self.model(img, text_for_pred, is_train=False)
            # select max probability (greedy decoding) then decode index to character
            _, preds_index = preds.max(2)
            preds_str = self.converter.decode(preds_index, length_for_pred)
        return self._strip_eos(preds_str[0])
def demo(opt):
    """Run the pretrained model over ``opt.image_folder`` and print results.

    Builds the converter and model described by ``opt``, loads
    ``opt.saved_model`` and prints one ``image_path<TAB>prediction`` line per
    image, with a header before every batch.

    Args:
        opt: Options namespace. ``num_class`` and (when ``opt.rgb`` is set)
            ``input_channel`` are written back onto it.
    """
    # --- model configuration ---
    if 'CTC' in opt.Prediction:
        converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial,
          opt.input_channel, opt.output_channel, opt.hidden_size,
          opt.num_class, opt.batch_max_length, opt.Transformation,
          opt.FeatureExtraction, opt.SequenceModeling, opt.Prediction)

    use_cuda = torch.cuda.is_available()
    model = torch.nn.DataParallel(model)
    if use_cuda:
        model = model.cuda()

    # load model
    print('loading pretrained model from %s' % opt.saved_model)
    if use_cuda:
        model.load_state_dict(torch.load(opt.saved_model))
    else:
        # Without CUDA the checkpoint tensors are remapped onto the CPU.
        model.load_state_dict(torch.load(opt.saved_model, map_location="cpu"))

    # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo
    AlignCollate_demo = AlignCollate(imgH=opt.imgH,
                                     imgW=opt.imgW,
                                     keep_ratio_with_pad=opt.PAD)
    demo_data = RawDataset(root=opt.image_folder, opt=opt)  # use RawDataset
    demo_loader = torch.utils.data.DataLoader(demo_data,
                                              batch_size=opt.batch_size,
                                              shuffle=False,
                                              num_workers=int(opt.workers),
                                              collate_fn=AlignCollate_demo,
                                              pin_memory=True)

    # predict
    model.eval()
    for image_tensors, image_path_list in demo_loader:
        batch_size = image_tensors.size(0)
        image = image_tensors
        # For max length prediction
        length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size)
        text_for_pred = torch.LongTensor(
            batch_size, opt.batch_max_length + 1).fill_(0)
        if use_cuda:
            image = image.cuda()
            length_for_pred = length_for_pred.cuda()
            text_for_pred = text_for_pred.cuda()

        # The forward pass itself must run under no_grad; previously only
        # the tensor allocation did, so autograd still tracked inference.
        with torch.no_grad():
            if 'CTC' in opt.Prediction:
                preds = model(image, text_for_pred).log_softmax(2)

                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = converter.decode(preds_index.data, preds_size.data)

            else:
                preds = model(image, text_for_pred, is_train=False)

                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = converter.decode(preds_index, length_for_pred)

        print('-' * 80)
        print('image_path\tpredicted_labels')
        print('-' * 80)
        for img_name, pred in zip(image_path_list, preds_str):
            if 'Attn' in opt.Prediction:
                # Prune after the "end of sentence" token ([s]). partition
                # keeps the whole string when the token is missing, whereas
                # slicing with find() == -1 dropped the last character.
                pred = pred.partition('[s]')[0]

            print(f'{img_name}\t{pred}')
class OcrRec:
    """Single-image text recognizer for the two-headed (CTC + attention) model.

    The constructor pins the model architecture and checkpoint path on the
    options object, loads the weights once, and ``text_rec`` then decodes
    one image per call using the head named by ``opt.Prediction``.
    """

    def __init__(self, opt=None):
        """Configure the options, then build and load the model.

        Args:
            opt: Options object; when falsy a fresh ``ConfigOpt()`` is used.
                The architecture fields and ``saved_model`` set below are
                overwritten on whichever object is used.
        """
        self.max_length = 25
        self.opt = ConfigOpt()
        if opt:
            self.opt = opt
        self.batch_size = 1  # text_rec processes exactly one image per call
        self.model = None

        # The checkpoint is resolved relative to this source file.
        self.cur_path = os.path.abspath(os.path.dirname(__file__))
        self.opt.saved_model = os.path.join(self.cur_path,
                                            "models/mtl_best_accuracy.pth")
        self.opt.Transformation = 'None'  # None|TPS
        self.opt.FeatureExtraction = 'ResNet'  # VGG|RCNN|ResNet
        self.opt.SequenceModeling = 'BiLSTM'  # None|BiLSTM
        self.opt.Prediction = 'CTC'  # CTC|Attn (use CTC or Attention in inference stage)
        # self.opt.output_channel = 512
        # self.opt.hidden_size = 256
        self.opt.output_channel = 768
        self.opt.hidden_size = 384
        self.opt.mtl = True

        self.ctc_converter = None
        self.attn_converter = None
        self.load_model()

    @staticmethod
    def _strip_eos(text):
        """Return ``text`` up to the first ``'[s]'`` end-of-sentence token.

        The whole string is returned when the token is absent. (Slicing with
        ``text.find('[s]')`` would drop the last character in that case,
        because ``find`` returns -1.)
        """
        return text.partition('[s]')[0]

    def load_model(self):
        """Create the converter and network from ``self.opt`` and load weights.

        Only the converter for the selected head is created. Both class
        counts are written to ``self.opt``: the attention count is the CTC
        count plus one. ``self.model`` ends up ``DataParallel``-wrapped and
        in eval mode.
        """
        if 'CTC' in self.opt.Prediction:
            self.ctc_converter = CTCLabelConverter(self.opt.character)
            self.opt.ctc_num_class = len(self.ctc_converter.character)
            self.opt.num_class = self.opt.ctc_num_class + 1
        else:
            self.attn_converter = AttnLabelConverter(self.opt.character)
            self.opt.num_class = len(self.attn_converter.character)
            self.opt.ctc_num_class = self.opt.num_class - 1
        if self.opt.rgb:
            self.opt.input_channel = 3
        self.model = Model(self.opt)
        print('model input parameters', self.opt.imgH, self.opt.imgW,
              self.opt.num_fiducial, self.opt.input_channel,
              self.opt.output_channel, self.opt.hidden_size,
              self.opt.num_class, self.opt.batch_max_length,
              self.opt.Transformation, self.opt.FeatureExtraction,
              self.opt.SequenceModeling, self.opt.Prediction)
        print(f"=====Use {self.opt.Prediction} prediction result=====")
        use_cuda = torch.cuda.is_available()
        self.model = torch.nn.DataParallel(self.model)
        if use_cuda:
            self.model = self.model.cuda()
        # load model
        print('loading pretrained model from %s' % self.opt.saved_model)
        if use_cuda:
            self.model.load_state_dict(torch.load(self.opt.saved_model))
        else:
            # Without CUDA the checkpoint tensors are remapped onto the CPU.
            self.model.load_state_dict(
                torch.load(self.opt.saved_model, map_location="cpu"))
        self.model.eval()

    def text_rec(self, img):
        """Recognize the text in one image.

        The image is converted to mode ``'L'``, resized to height
        ``self.opt.imgH`` with its aspect ratio kept, and passed through the
        model. The first model output is decoded with the CTC converter when
        ``'CTC'`` is in ``opt.Prediction``; otherwise the second output is
        decoded with the attention converter.

        :param img: path to an image file, or an already opened image
            object exposing ``mode``, ``size`` and ``convert``
            (presumably a PIL image - confirm against callers)
        :return: the decoded string for the image
        """
        if isinstance(img, str) and os.path.isfile(img):
            # Convert inside the context manager so the underlying file
            # handle is closed as soon as the pixels have been read.
            with Image.open(img) as opened:
                img = opened.convert('L')
        if img.mode != 'L':
            img = img.convert('L')

        # Scale the width by the same factor that maps the height to imgH.
        ratio = self.opt.imgH / img.size[1]
        target_w = int(img.size[0] * ratio)
        transformer = InferResizeNormalize((target_w, self.opt.imgH))
        img = transformer(img)
        img = img.view(1, *img.size())  # prepend a batch dimension of 1

        # Fixed-length decoding inputs: maximum length for every sample and
        # an all-zero text tensor.
        length_for_pred = torch.IntTensor([self.opt.batch_max_length] *
                                          self.batch_size)
        text_for_pred = torch.LongTensor(
            self.batch_size, self.opt.batch_max_length + 1).fill_(0)
        if torch.cuda.is_available():
            img = img.cuda()
            length_for_pred = length_for_pred.cuda()
            text_for_pred = text_for_pred.cuda()

        with torch.no_grad():
            if 'CTC' in self.opt.Prediction:
                preds, _ = self.model(img, text_for_pred, is_train=False)
                preds = preds.softmax(2)
                # Select max probability (greedy decoding) then decode index to character
                preds_size = torch.IntTensor([preds.size(1)] * self.batch_size)
                _, preds_index = preds.permute(1, 0, 2).max(2)
                preds_index = preds_index.transpose(1, 0).contiguous().view(-1)
                preds_str = self.ctc_converter.decode(preds_index.data,
                                                      preds_size.data)
            else:
                # Same fallback as load_model: anything that is not CTC is
                # handled as attention, so preds_str is always bound.
                _, preds = self.model(img, text_for_pred, is_train=False)
                # select max probability (greedy decoding) then decode index to character
                _, preds_index = preds.max(2)
                preds_str = self.attn_converter.decode(preds_index,
                                                       length_for_pred)
                preds_str = [self._strip_eos(pred) for pred in preds_str]
        return preds_str[0]