Example #1
def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    print("Saving model and optimizer state at iteration {} to {}".format(
        iteration, filepath))
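    # Copy the weights into a fresh CRNN instance (CRNN and CRNN_config are
    # assumed to be defined at module level), so the checkpoint stores a
    # standalone model object rather than the live training instance.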
    model_for_saving = CRNN(**CRNN_config).cuda()
    model_for_saving.load_state_dict(model.state_dict())
    torch.save(
        {
            'model': model_for_saving,
            'iteration': iteration,
            'optimizer': optimizer.state_dict(),
            'learning_rate': learning_rate
        }, filepath)
Example #2
    val_loader = torch.utils.data.DataLoader(testdataset,
                                             shuffle=False,
                                             batch_size=opt.batch_size,
                                             num_workers=int(opt.workers),
                                             collate_fn=alignCollate(
                                                 imgH=imgH,
                                                 imgW=imgW,
                                                 keep_ratio=keep_ratio))
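    # alignCollate is assumed to resize (and optionally aspect-ratio pad) every
    # image in a batch to the same (imgH, imgW) so the crops can be stacked
    # into a single tensor.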

    alphabet = keys.alphabetChinese
    print("char num ", len(alphabet))
    model = CRNN(32, 1, len(alphabet) + 1, 256, 1)

    converter = strLabelConverter(''.join(alphabet))

    state_dict = torch.load("../SceneOcr/model/ocr-lstm.pth",
                            map_location=lambda storage, loc: storage)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k
        if "num_batches_tracked" not in k:
            # name = name.replace('module.', '')  # remove `module.`
            new_state_dict[name] = v
    # load params into the bare model first, so the checkpoint keys do not
    # need the 'module.' prefix added by DataParallel
    model.load_state_dict(new_state_dict)
    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
    model.eval()

    curAcc = val(model, converter, val_loader, max_iter=5)
Example #3
import torch
from torch.autograd import Variable

import utils
import dataset
from PIL import Image
from models.crnn import CRNN


model_path = './data/crnn.pth'
img_path = './data/demo.png'
alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'

model = CRNN(32, 1, 37, 256)

if torch.cuda.is_available():
    model = model.cuda()
print('loading pretrained model from %s' % model_path)

model.load_state_dict(torch.load(model_path))

converter = utils.strLabelConverter(alphabet)

transformer = dataset.resizeNormalize((100, 32))
image = Image.open(img_path).convert('L')
image = transformer(image)
if torch.cuda.is_available():
    image = image.cuda()
image = image.view(1, *image.size())
image = Variable(image)

model.eval()
preds = model(image)

_, preds = preds.max(2)
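# A typical greedy CTC decode (as in the upstream crnn.pytorch demo) flattens the
# per-timestep argmax indices and lets the converter collapse repeats and blanks:
preds = preds.transpose(1, 0).contiguous().view(-1)
preds_size = Variable(torch.IntTensor([preds.size(0)]))
raw_pred = converter.decode(preds.data, preds_size.data, raw=True)
sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
print('%-20s => %-20s' % (raw_pred, sim_pred))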
Example #4
class Demo(object):
    def __init__(self, args):
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
        self.args = args
        self.alphabet = alphabetChinese
        nclass = len(self.alphabet) + 1
        nc = 1
        self.net = CRNN(args.imgH, nc, args.nh, nclass)
        self.converter = utils.strLabelConverter(self.alphabet, ignore_case=False)
        self.transformer = resizeNormalize(args.imgH)

        print('loading pretrained model from %s' % args.model_path)
        checkpoint = torch.load(args.model_path)
        if 'model_state_dict' in checkpoint.keys():
            checkpoint = checkpoint['model_state_dict']
        from collections import OrderedDict
        model_dict = OrderedDict()
        for k, v in checkpoint.items():
            if 'module' in k:
                model_dict[k[7:]] = v
            else:
                model_dict[k] = v
        self.net.load_state_dict(model_dict)

        if args.cuda and torch.cuda.is_available():
            print('number of available GPUs:', torch.cuda.device_count())
            self.net = torch.nn.DataParallel(self.net).cuda()
        
        self.net.eval()
    
    def predict(self, image):
        image = self.transformer(image)
        if torch.cuda.is_available():
            image = image.cuda()
        image = image.view(1, *image.size())
        image = Variable(image)

        preds = self.net(image)
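        # Greedy CTC decode: take the argmax class at each time step; the
        # converter then collapses repeated characters and blank symbols.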
        _, preds = preds.max(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        preds_size = Variable(torch.IntTensor([preds.size(0)]))
        raw_pred = self.converter.decode(preds.data, preds_size.data, raw=True)
        sim_pred = self.converter.decode(preds.data, preds_size.data, raw=False)
        print('%-20s => %-20s' % (raw_pred, sim_pred))

        return sim_pred

    def predict_batch(self, images):
        N = len(images)
        n_batch = N // self.args.batch_size
        n_batch += 1 if N % self.args.batch_size else 0
        res = []
        for i in range(n_batch):
            batch = images[i*self.args.batch_size : min((i+1)*self.args.batch_size, N)]
            # first pass: apply the transform and record the widest image in the batch
            maxW = 0
            for j in range(len(batch)):
                batch[j] = self.transformer(batch[j])
                imgW = batch[j].shape[2]
                maxW = max(maxW, imgW)

            # second pass: right-pad with zeros so every image in the batch has width maxW
            for j in range(len(batch)):
                if batch[j].shape[2] < maxW:
                    batch[j] = torch.cat((batch[j], torch.zeros((1, self.args.imgH, maxW - batch[j].shape[2]), dtype=batch[j].dtype)), 2)
            batch_imgs = torch.cat([t.unsqueeze(0) for t in batch], 0)
            preds = self.net(batch_imgs)
            preds_size = Variable(torch.IntTensor([preds.size(0)]*len(batch)))
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            raw_preds = self.converter.decode(preds.data, preds_size.data, raw=True)
            sim_preds = self.converter.decode(preds.data, preds_size.data, raw=False)
            for raw_pred, sim_pred in zip(raw_preds, sim_preds):
                print('%-20s => %-20s' % (raw_pred, sim_pred))
            res.extend(sim_preds)
        return res

    def inference(self, image_path, batch_pred=False):
        if os.path.isdir(image_path):
            file_list = os.listdir(image_path)
            image_list = [os.path.join(image_path, i) for i in file_list if i.rsplit('.')[-1].lower() in img_types] 
        else:
            image_list = [image_path]
        
        res = []
        images = []
        for img_path in image_list:
            image = Image.open(img_path).convert('L')
            if not batch_pred:
                sim_pred = self.predict(image)
                res.append(sim_pred)
            else:
                images.append(image)
        if batch_pred and images:
            res = self.predict_batch(images)
        return res
Example #5
    # Load SSD model
    PATH_TO_FROZEN_GRAPH = args.detection_model_path
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as f:
            od_graph_def.ParseFromString(f.read())
            tf.import_graph_def(od_graph_def, name='')

    # Load CRNN model
    alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
    crnn = CRNN(32, 1, 37, 256)
    if torch.cuda.is_available():
        crnn = crnn.cuda()
    crnn.load_state_dict(torch.load(args.recognition_model_path))
    converter = utils.strLabelConverter(alphabet)
    transformer = dataset.resizeNormalize((100, 32))
    crnn.eval()

    # Open a video file or an image file
    cap = cv2.VideoCapture(args.input if args.input else 0)

    while cv2.waitKey(1) < 0:
        has_frame, frame = cap.read()
        if not has_frame:
            cv2.waitKey(0)
            break

        im_height, im_width, _ = frame.shape
        tf_frame = np.expand_dims(frame, axis=0)
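        # From here the usual pipeline is: run the detection graph on tf_frame,
        # crop each detected text region from the frame, and feed the crops
        # through `transformer` and `crnn` as in the single-image demo.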
Example #6
def main():
    # Set parameters of the trainer
    global args, device
    args = parse_args()

    print('=' * 60)
    print(args)
    print('=' * 60)
    random.seed(args.manual_seed)
    np.random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)

    if torch.cuda.is_available() and not args.cuda:
        print(
            "WARNING: You have a CUDA device, so you should probably run with --cuda"
        )
    if args.cuda and torch.cuda.is_available():
        device = torch.device('cuda')
        #cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    # load alphabet from file
    if os.path.isfile(args.alphabet):
        alphabet = ''
        with open(args.alphabet, mode='rb') as f:
            for line in f.readlines():
                alphabet += line.decode('utf-8')[0]
        args.alphabet = alphabet

    converter = utils.CTCLabelConverter(args.alphabet, ignore_case=False)

    # data loader
    image_size = (args.image_h, args.image_w)
    collater = DatasetCollater(image_size, keep_ratio=args.keep_ratio)
    train_dataset = Dataset(mode='train',
                            data_root=args.data_root,
                            transform=None)
    #sampler = RandomSequentialSampler(train_dataset, args.batch_size)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   collate_fn=collater,
                                   shuffle=True,
                                   num_workers=args.workers)

    val_dataset = Dataset(mode='val', data_root=args.data_root, transform=None)
    val_loader = data.DataLoader(val_dataset,
                                 batch_size=args.batch_size,
                                 collate_fn=collater,
                                 shuffle=True,
                                 num_workers=args.workers)

    # network
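    # the extra output class is the CTC blank symbol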
    num_classes = len(args.alphabet) + 1
    num_channels = 1
    if args.arch == 'crnn':
        model = CRNN(args.image_h, num_channels, num_classes, args.num_hidden)
    elif args.arch == 'densenet':
        model = DenseNet(
            num_channels=num_channels,
            num_classes=num_classes,
            growth_rate=12,
            block_config=(3, 6, 9),  #(3,6,12,16),
            compression=0.5,
            num_init_features=64,
            bn_size=4,
            rnn=args.rnn,
            num_hidden=args.num_hidden,
            drop_rate=0,
            small_inputs=True,
            efficient=False)
    else:
        raise ValueError('unknown architecture {}'.format(args.arch))
    model = model.to(device)
    summary(model, torch.zeros((2, 1, 32, 650)).to(device))
    #print('='*60)
    #print(model)
    #print('='*60)

    # loss
    criterion = CTCLoss()
    criterion = criterion.to(device)

    # setup optimizer
    if args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'rmsprop':
        optimizer = optim.RMSprop(model.parameters(),
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               betas=(0.9, 0.999),
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'adadelta':
        optimizer = optim.Adadelta(model.parameters(),
                                   weight_decay=args.weight_decay)
    else:
        raise ValueError('unknown optimizer {}'.format(args.optimizer))
    print('=' * 60)
    print(optimizer)
    print('=' * 60)

    # Define learning rate decay schedule
    global scheduler
    #exp_decay = math.exp(-0.1)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer,
                                                 gamma=args.decay_rate)
    #step_size = 10000
    #gamma_decay = 0.8
    #scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma_decay)
    #scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=gamma_decay)

    # initialize model
    if args.pretrained and os.path.isfile(args.pretrained):
        print(">> Using pre-trained model '{}'".format(
            os.path.basename(args.pretrained)))
        state_dict = torch.load(args.pretrained)
        model.load_state_dict(state_dict)
        print("loading pretrained model done.")

    global is_best, best_accuracy
    is_best = False
    best_accuracy = 0.0
    start_epoch = 0
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            # load checkpoint weights and update model and optimizer
            print(">> Loading checkpoint:\n>> '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_accuracy = checkpoint['best_accuracy']
            print(">>>> loaded checkpoint:\n>>>> '{}' (epoch {})".format(
                args.resume, start_epoch))
            model.load_state_dict(checkpoint['state_dict'])
            #optimizer.load_state_dict(checkpoint['optimizer'])
            # important not to forget scheduler updating
            #scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.decay_rate, last_epoch=start_epoch - 1)
        else:
            print(">> No checkpoint found at '{}'".format(args.resume))

    # Create the export dir if it doesn't exist
    checkpoint = "{}".format(args.arch)
    checkpoint += "_{}".format(args.optimizer)
    checkpoint += "_lr_{}".format(args.lr)
    checkpoint += "_decay_rate_{}".format(args.decay_rate)
    checkpoint += "_bsize_{}".format(args.batch_size)
    checkpoint += "_height_{}".format(args.image_h)
    checkpoint += "_keep_ratio" if args.keep_ratio else "_width_{}".format(
        image_size[1])

    args.checkpoint = os.path.join(args.checkpoint, checkpoint)
    if not os.path.exists(args.checkpoint):
        os.makedirs(args.checkpoint)

    print('start training...')
    for epoch in range(start_epoch, args.max_epoch):
        # Adjust the learning rate for each epoch
        scheduler.step()
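        # NOTE: since PyTorch 1.1 the scheduler is expected to be stepped after
        # the epoch's optimizer steps; stepping it here follows the older API.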

        # Train for one epoch on train set
        _ = train(train_loader, val_loader, model, criterion, optimizer, epoch,
                  converter)
Example #7
def main():
    config = Config()

    if not os.path.exists(config.expr_dir):
        os.makedirs(config.expr_dir)

    if torch.cuda.is_available() and not config.use_cuda:
        print("WARNING: You have a CUDA device, so you should probably set cuda in params.py to True")

    # Load the training dataset
    train_dataset = HubDataset(config, "train", transform=None)

    train_kwargs = {'num_workers': 2, 'pin_memory': True,
                    'collate_fn': alignCollate(config.img_height, config.img_width, config.keep_ratio)} if torch.cuda.is_available() else {}

    training_data_batch = DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, drop_last=False, **train_kwargs)

    # Load the fixed-length validation dataset
    eval_dataset = HubDataset(config, "eval", transform=transforms.Compose([ResizeNormalize(config.img_height, config.img_width)]))
    eval_kwargs = {'num_workers': 2, 'pin_memory': False} if torch.cuda.is_available() else {}
    eval_data_batch = DataLoader(eval_dataset, batch_size=config.eval_batch_size, shuffle=False, drop_last=False, **eval_kwargs)

    # Load the variable-length validation dataset
    # eval_dataset = HubDataset(config, "eval")
    # eval_kwargs = {'num_workers': 2, 'pin_memory': False,
    #                'collate_fn': alignCollate(config.img_height, config.img_width, config.keep_ratio)} if torch.cuda.is_available() else {}
    # eval_data_batch = DataLoader(eval_dataset, batch_size=config.eval_batch_size, shuffle=False, drop_last=False, **eval_kwargs)

    # Define the network model
    nclass = len(config.label_classes) + 1
    crnn = CRNN(config.img_height, config.nc, nclass, config.hidden_size, n_rnn=config.n_layers)
    # Load a pretrained model
    if config.pretrained != '':
        print('loading pretrained model from %s' % config.pretrained)
        crnn.load_state_dict(torch.load(config.pretrained))
    print(crnn)

    # Compute average for `torch.Variable` and `torch.Tensor`.
    loss_avg = utils.averager()

    # Convert between str and label.
    converter = utils.strLabelConverter(config.label_classes)

    criterion = CTCLoss()           # define the loss function

    # Set up placeholder tensors; utils.loadData fills them with each batch
    image = torch.FloatTensor(config.train_batch_size, config.nc, config.img_height, config.img_width)
    text = torch.LongTensor(config.train_batch_size * 5)
    length = torch.LongTensor(config.train_batch_size)

    if config.use_cuda and torch.cuda.is_available():
        criterion = criterion.cuda()
        image = image.cuda()
        crnn = crnn.to(config.device)

    image = Variable(image)
    text = Variable(text)
    length = Variable(length)

    # Set up the optimizer
    if config.adam:
        optimizer = optim.Adam(crnn.parameters(), lr=config.lr, betas=(config.beta1, 0.999))
    elif config.adadelta:
        optimizer = optim.Adadelta(crnn.parameters())
    else:
        optimizer = optim.RMSprop(crnn.parameters(), lr=config.lr)

    def val(net, criterion, eval_data_batch):
        print('Start val')
        for p in net.parameters():
            p.requires_grad = False
        net.eval()

        n_correct = 0
        loss_avg_eval = utils.averager()
        for data in eval_data_batch:
            cpu_images, cpu_texts = data
            batch_size = cpu_images.size(0)
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts)
            utils.loadData(text, t)
            utils.loadData(length, l)
            preds = net(image)
            preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
            cost = criterion(preds, text, preds_size, length) / batch_size
            loss_avg_eval.add(cost)         # accumulate the validation loss

            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
            cpu_texts_decode = list(cpu_texts)
            for pred, target in zip(sim_preds, cpu_texts_decode):       # count exact matches
                if pred == target:
                    n_correct += 1

            raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:config.n_val_disp]
            for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts_decode):
                print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

        accuracy = n_correct / float(len(eval_dataset))
        print('Val loss: %f, accuracy: %f' % (loss_avg_eval.val(), accuracy))

    # Train on a single batch
    def train(net, criterion, optimizer, data):
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)             # current batch size
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts)          # encode the text labels as class indices
        utils.loadData(text, t)
        utils.loadData(length, l)
        optimizer.zero_grad()                       # zero the gradients
        preds = net(image)
        preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        cost.backward()
        optimizer.step()
        return cost

    for epoch in range(config.nepoch):
        i = 0
        for batch_data in training_data_batch:
            for p in crnn.parameters():
                p.requires_grad = True
            crnn.train()
            cost = train(crnn, criterion, optimizer, batch_data)
            loss_avg.add(cost)
            i += 1

            if i % config.displayInterval == 0:
                print('[%d/%d][%d/%d] Loss: %f' %
                      (epoch, config.nepoch, i, len(training_data_batch), loss_avg.val()))
                loss_avg.reset()

            # if i % config.valInterval == 0:
            #     val(crnn, criterion, eval_data_batch)
            #
            # # do checkpointing
            # if i % config.saveInterval == 0:
            #     torch.save(crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(config.expr_dir, epoch, i))

        val(crnn, criterion, eval_data_batch)
        torch.save(crnn.state_dict(), '{0}/netCRNN_{1}_end.pth'.format(config.expr_dir, epoch))
Example #8
    random.seed(opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    transformer = dataset.resizeNormalize((100, 32))

    nclass = len(opt.alphabet) + 1
    nc = 1

    converter = misc.strLabelConverter(opt.alphabet)

    crnn = CRNN(opt.imgH, nc, nclass, opt.nh)
    if opt.pretrained != '':
        print('loading pretrained model from %s' % opt.pretrained)
        crnn.load_state_dict(load_multi(opt.pretrained), strict=False)

    # Process pruned conv2d and batchnorm2d, store them in a dictionary
    crnn_l = list(crnn.cnn._modules.items())

    last_channels = [0]

    # start from a deep copy of the original network; its pruned layers will be overwritten
    crnn_new = copy.deepcopy(crnn)
    new_dict = {}

    for i in range(len(crnn_l)):
        module = crnn_l[i][1]
        if isinstance(module, torch.nn.Conv2d):
            out_channels = get_out_channel(module)
            new = set_weight_conv(last_channels, out_channels, module)
Example #9
class Trainer(object):
    def __init__(self):
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus
        if args.chars_file == '':
            self.alphabet = alphabetChinese
        else:
            self.alphabet = utils.load_chars(args.chars_file)
        nclass = len(self.alphabet) + 1
        nc = 1
        self.net = CRNN(args.imgH, nc, args.nh, nclass)
        self.train_dataloader, self.val_dataloader = self.dataloader(
            self.alphabet)
        self.criterion = CTCLoss()
        self.optimizer = self.get_optimizer()
        self.converter = utils.strLabelConverter(self.alphabet,
                                                 ignore_case=False)
        self.best_acc = 0.00001

        model_name = '%s' % (args.dataset_name)
        if not os.path.exists(args.save_prefix):
            os.mkdir(args.save_prefix)
        args.save_prefix += model_name

        if args.pretrained != '':
            print('loading pretrained model from %s' % args.pretrained)
            checkpoint = torch.load(args.pretrained)

            if 'model_state_dict' in checkpoint.keys():
                # self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                args.start_epoch = checkpoint['epoch']
                self.best_acc = checkpoint['best_acc']
                checkpoint = checkpoint['model_state_dict']

            from collections import OrderedDict
            model_dict = OrderedDict()
            for k, v in checkpoint.items():
                if 'module' in k:
                    model_dict[k[7:]] = v
                else:
                    model_dict[k] = v
            self.net.load_state_dict(model_dict)

        if not args.cuda and torch.cuda.is_available():
            print(
                "WARNING: You have a CUDA device, so you should probably run with --cuda"
            )

        elif args.cuda and torch.cuda.is_available():
            print('number of available GPUs:', torch.cuda.device_count())
            self.net = torch.nn.DataParallel(self.net).cuda()
            self.criterion = self.criterion.cuda()

    def dataloader(self, alphabet):
        # train_transform = transforms.Compose(
        #     [transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
        #     resizeNormalize(args.imgH)])
        # train_dataset = BaseDataset(args.train_dir, alphabet, transform=train_transform)
        train_dataset = NumDataset(args.train_dir,
                                   alphabet,
                                   transform=resizeNormalize(args.imgH))
        train_dataloader = DataLoader(dataset=train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)

        if os.path.exists(args.val_dir):
            # val_dataset = BaseDataset(args.val_dir, alphabet, transform=resizeNormalize(args.imgH))
            val_dataset = NumDataset(args.val_dir,
                                     alphabet,
                                     mode='test',
                                     transform=resizeNormalize(args.imgH))
            val_dataloader = DataLoader(dataset=val_dataset,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers,
                                        pin_memory=True)
        else:
            val_dataloader = None

        return train_dataloader, val_dataloader

    def get_optimizer(self):
        if args.optimizer == 'sgd':
            optimizer = optim.SGD(
                self.net.parameters(),
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.wd,
            )
        elif args.optimizer == 'adam':
            optimizer = optim.Adam(
                self.net.parameters(),
                lr=args.lr,
                betas=(args.beta1, 0.999),
            )
        else:
            optimizer = optim.RMSprop(
                self.net.parameters(),
                lr=args.lr,
                momentum=args.momentum,
                weight_decay=args.wd,
            )
        return optimizer

    def train(self):
        logging.basicConfig()
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        log_file_path = args.save_prefix + '_train.log'
        log_dir = os.path.dirname(log_file_path)
        if log_dir and not os.path.exists(log_dir):
            os.mkdir(log_dir)
        fh = logging.FileHandler(log_file_path)
        logger.addHandler(fh)
        logger.info(args)
        logger.info('Start training from [Epoch {}]'.format(args.start_epoch +
                                                            1))

        losses = utils.Averager()
        train_accuracy = utils.Averager()

        for epoch in range(args.start_epoch, args.nepoch):
            self.net.train()
            btic = time.time()
            for i, (imgs, labels) in enumerate(self.train_dataloader):
                batch_size = imgs.size()[0]
                imgs = imgs.cuda()
                preds = self.net(imgs).cpu()
                text, length = self.converter.encode(
                    labels
                )  # length: character count of each sample in the batch; text: indices of all characters in the batch
                preds_size = torch.IntTensor([preds.size(0)] * batch_size)
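                # every sample uses the full output length preds.size(0) as its CTC input length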
                loss_avg = self.criterion(preds, text, preds_size,
                                          length) / batch_size

                self.optimizer.zero_grad()
                loss_avg.backward()
                self.optimizer.step()

                losses.update(loss_avg.item(), batch_size)

                _, preds_m = preds.max(2)
                preds_m = preds_m.transpose(1, 0).contiguous().view(-1)
                sim_preds = self.converter.decode(preds_m.data,
                                                  preds_size.data,
                                                  raw=False)
                n_correct = 0
                for pred, target in zip(sim_preds, labels):
                    if pred == target:
                        n_correct += 1
                train_accuracy.update(n_correct, batch_size, MUL_n=False)

                if args.log_interval and not (i + 1) % args.log_interval:
                    logger.info(
                        '[Epoch {}/{}][Batch {}/{}], Speed: {:.3f} samples/sec, Loss:{:.3f}'
                        .format(epoch + 1, args.nepoch, i + 1,
                                len(self.train_dataloader),
                                batch_size / (time.time() - btic),
                                losses.val()))
                    losses.reset()

            logger.info(
                'Training accuracy: {:.3f}, [#correct:{} / #total:{}]'.format(
                    train_accuracy.val(), train_accuracy.sum,
                    train_accuracy.count))
            train_accuracy.reset()

            if args.val_interval and not (epoch + 1) % args.val_interval:
                acc = self.validate(logger)
                if acc > self.best_acc:
                    self.best_acc = acc
                    save_path = '{:s}_best.pth'.format(args.save_prefix)
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': self.net.state_dict(),
                            # 'optimizer_state_dict': self.optimizer.state_dict(),
                            'best_acc': self.best_acc,
                        },
                        save_path)
                logger.info("best acc is: {:.3f}".format(self.best_acc))
                if args.save_interval and not (epoch + 1) % args.save_interval:
                    save_path = '{:s}_{:04d}_{:.3f}.pth'.format(
                        args.save_prefix, epoch + 1, acc)
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': self.net.state_dict(),
                            # 'optimizer_state_dict': self.optimizer.state_dict(),
                            'best_acc': self.best_acc,
                        },
                        save_path)

    def validate(self, logger):
        if self.val_dataloader is None:
            return 0
        logger.info('Start validate.')
        losses = utils.Averager()
        self.net.eval()
        n_correct = 0
        with torch.no_grad():
            for i, (imgs, labels) in enumerate(self.val_dataloader):
                batch_size = imgs.size()[0]
                imgs = imgs.cuda()
                preds = self.net(imgs).cpu()
                text, length = self.converter.encode(
                    labels
                )  # length: character count of each sample in the batch; text: indices of all characters in the batch
                preds_size = torch.IntTensor(
                    [preds.size(0)] * batch_size)  # timestep * batchsize
                loss_avg = self.criterion(preds, text, preds_size,
                                          length) / batch_size

                losses.update(loss_avg.item(), batch_size)

                _, preds = preds.max(2)
                preds = preds.transpose(1, 0).contiguous().view(-1)
                sim_preds = self.converter.decode(preds.data,
                                                  preds_size.data,
                                                  raw=False)
                for pred, target in zip(sim_preds, labels):
                    if pred == target:
                        n_correct += 1

        accuracy = n_correct / float(losses.count)

        logger.info(
            'Validation loss: {:.3f}, accuracy: {:.3f}, [#correct:{} / #total:{}]'
            .format(losses.val(), accuracy, n_correct, losses.count))

        return accuracy