Ejemplo n.º 1
0
    def __init__(self, model_path):
        """Load a recognition checkpoint and prepare the model for inference.

        :param model_path: path to a checkpoint dict holding ``'cfg'`` and
            ``'state_dict'`` entries. ``torch.load`` unpickles the file, so only
            load checkpoints from trusted sources.
        """
        ckpt = torch.load(model_path, map_location='cpu')
        cfg = ckpt['cfg']
        # Build the network from the config saved with the weights (the same
        # config the dataset/alphabet settings below are read from) so the
        # architecture matches the state_dict; a separately imported local
        # config file can silently drift from the checkpoint.
        self.model = build_model(cfg['model'])
        # Checkpoints saved from an nn.DataParallel wrapper prefix every key
        # with 'module.'. Strip only that leading prefix so wrapped and
        # unwrapped checkpoints both load into the bare model.
        prefix = 'module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in ckpt['state_dict'].items()
        }
        self.model.load_state_dict(state_dict)

        # Keep the device on the instance so inputs can be moved to it later.
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        # Inference only: switch layers such as BatchNorm/Dropout to eval mode.
        self.model.eval()

        self.process = RecDataProcess(cfg['dataset']['train']['dataset'])
        self.converter = CTCLabelConverter(cfg['dataset']['alphabet'])
Ejemplo n.º 2
0
    def __init__(self, model_path):
        """Load a recognition checkpoint and prepare the model for inference.

        :param model_path: path to a checkpoint dict holding ``'cfg'`` and
            ``'state_dict'`` entries. ``torch.load`` unpickles the file, so only
            load checkpoints from trusted sources.
        """
        ckpt = torch.load(model_path, map_location='cpu')
        cfg = ckpt['cfg']
        self.model = build_model(cfg['model'])
        # Checkpoints saved from an nn.DataParallel wrapper prefix every key
        # with 'module.'. Strip only that leading prefix: str.replace would also
        # mangle keys that merely contain 'module.' (e.g. '...submodule.weight').
        prefix = 'module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in ckpt['state_dict'].items()
        }
        self.model.load_state_dict(state_dict)

        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        # Inference only: switch layers such as BatchNorm/Dropout to eval mode.
        self.model.eval()

        self.process = RecDataProcess(cfg['dataset']['train']['dataset'])
        self.converter = CTCLabelConverter(cfg['dataset']['alphabet'])
Ejemplo n.º 3
0
class RecInfer:
    """Single-image text-recognition inference around a trained CTC checkpoint."""

    def __init__(self, model_path):
        """Load a recognition checkpoint and prepare the model for inference.

        :param model_path: path to a checkpoint dict holding ``'cfg'`` and
            ``'state_dict'`` entries. ``torch.load`` unpickles the file, so only
            load checkpoints from trusted sources.
        """
        ckpt = torch.load(model_path, map_location='cpu')
        cfg = ckpt['cfg']
        self.model = build_model(cfg['model'])
        # Checkpoints saved from an nn.DataParallel wrapper prefix every key
        # with 'module.'. Strip only that leading prefix: str.replace would also
        # mangle keys that merely contain 'module.' (e.g. '...submodule.weight').
        prefix = 'module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in ckpt['state_dict'].items()
        }
        self.model.load_state_dict(state_dict)

        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

        self.process = RecDataProcess(cfg['dataset']['train']['dataset'])
        self.converter = CTCLabelConverter(cfg['dataset']['alphabet'])

    def predict(self, img):
        """Recognize the text in one image.

        :param img: a single image; it is transposed with ``[2, 0, 1]`` below,
            so presumably an HWC numpy array (confirm against callers).
        :return: whatever ``self.converter.decode`` returns for a batch of one.
        """
        # Preprocess the same way as during training.
        img = self.process.resize_with_specific_height(img)
        img = self.process.normalize_img(img)
        # HWC -> CHW, then add the batch dimension.
        tensor = torch.from_numpy(img.transpose([2, 0, 1])).float()
        tensor = tensor.unsqueeze(dim=0).to(self.device)
        # Pure inference: skip building an autograd graph.
        with torch.no_grad():
            out = self.model(tensor)
        return self.converter.decode(out.softmax(dim=2).cpu().numpy())
Ejemplo n.º 4
0
    def __init__(self, model_path):
        """Load a recognition checkpoint and prepare the model for inference.

        :param model_path: path to a checkpoint dict holding ``'cfg'`` and
            ``'state_dict'`` entries. ``torch.load`` unpickles the file, so only
            load checkpoints from trusted sources.
        """
        ckpt = torch.load(model_path, map_location='cpu')
        cfg = ckpt['cfg']
        self.model = build_model(cfg['model'])
        # Checkpoints saved from an nn.DataParallel wrapper prefix every key
        # with 'module.'. Strip only that leading prefix: str.replace would also
        # mangle keys that merely contain 'module.' (e.g. '...submodule.weight').
        prefix = 'module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in ckpt['state_dict'].items()
        }
        self.model.load_state_dict(state_dict)

        # This variant targets the second GPU. Fall back to the first GPU when
        # only one is visible: moving to a non-existent 'cuda:1' would fail.
        if torch.cuda.is_available():
            gpu_index = 1 if torch.cuda.device_count() > 1 else 0
            self.device = torch.device(f'cuda:{gpu_index}')
        else:
            self.device = torch.device('cpu')
        self.model.to(self.device)
        self.model.eval()

        self.process = RecDataProcess(cfg['dataset']['train']['dataset'])
        self.converter = CTCLabelConverter(cfg['dataset']['alphabet'])
        # Debug output: echo the character set the converter was built with.
        print(cfg['dataset']['alphabet'])
Ejemplo n.º 5
0
class RecInfer:
    """Batched text-recognition inference around a trained CTC checkpoint."""

    def __init__(self, model_path, batch_size=16):
        """Load a recognition checkpoint and prepare the model for inference.

        :param model_path: path to a checkpoint dict holding ``'cfg'`` and
            ``'state_dict'`` entries. ``torch.load`` unpickles the file, so only
            load checkpoints from trusted sources.
        :param batch_size: maximum number of images per forward pass in ``predict``.
        """
        ckpt = torch.load(model_path, map_location='cpu')
        cfg = ckpt['cfg']
        self.model = build_model(cfg['model'])
        # Checkpoints saved from an nn.DataParallel wrapper prefix every key
        # with 'module.'. Strip only that leading prefix: str.replace would also
        # mangle keys that merely contain 'module.' (e.g. '...submodule.weight').
        prefix = 'module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in ckpt['state_dict'].items()
        }
        self.model.load_state_dict(state_dict)

        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

        self.process = RecDataProcess(cfg['dataset']['train']['dataset'])
        self.converter = CTCLabelConverter(cfg['dataset']['alphabet'])
        self.batch_size = batch_size

    def predict(self, imgs):
        """Recognize text in one image or a list/tuple of images.

        Images are grouped by width so that each batch needs little padding;
        results are returned in the caller's input order.

        :param imgs: a single image or a list/tuple of images. They are
            transposed as NHWC -> NCHW below, so presumably HWC numpy arrays
            (confirm against callers).
        :return: a list with one ``self.converter.decode`` result per input image.
        """
        # Anything that is not a list/tuple is treated as a single image.
        if not isinstance(imgs, (list, tuple)):
            imgs = [imgs]
        # Preprocess the same way as during training.
        imgs = [
            self.process.normalize_img(self.process.resize_with_specific_height(img))
            for img in imgs
        ]
        # Ascending-width order so images in the same batch have similar widths.
        order = np.argsort(np.array([img.shape[1] for img in imgs]))
        txts = []
        for start in range(0, len(order), self.batch_size):
            batch_order = order[start:start + self.batch_size]
            # Widths ascend, so the last image of the batch is the widest one.
            max_width = imgs[batch_order[-1]].shape[1]
            batch_imgs = np.stack([
                self.process.width_pad_img(imgs[j], max_width) for j in batch_order
            ])
            # NHWC -> NCHW for the model.
            tensor = torch.from_numpy(batch_imgs.transpose([0, 3, 1, 2])).float()
            tensor = tensor.to(self.device)
            with torch.no_grad():
                out = self.model(tensor).softmax(dim=2)
            out = out.cpu().numpy()
            txts.extend(
                self.converter.decode(np.expand_dims(sample, 0)) for sample in out)
        # argsort of a permutation is its inverse: for each input position it
        # gives that image's position in the width-sorted results.
        inverse = np.argsort(order)
        return [txts[k] for k in inverse]
Ejemplo n.º 6
0
def train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device,
          cfg, global_state, logger):
    """
    Training loop for the CTC text-recognition model.

    :param net: the network
    :param optimizer: optimizer
    :param scheduler: learning-rate scheduler, stepped once per epoch
    :param loss_func: loss function, called as ``loss_func(output, batch_data)``;
        its result must contain a ``'loss'`` tensor
    :param train_loader: training dataloader
    :param eval_loader: validation dataloader
    :param to_use_device: device the input images are moved to
    :param cfg: the configuration used for this training run
    :param global_state: global training state (epoch to resume from, global
        step, info about the best model so far). An empty dict starts from
        scratch; it is updated in place before every checkpoint is written.
    :param logger: logger object
    :return: None
    """

    from torchocr.metrics import RecMetric
    from torchocr.utils import CTCLabelConverter
    converter = CTCLabelConverter(cfg.dataset.alphabet)
    train_options = cfg.train_options
    metric = RecMetric(converter)
    # ===>
    logger.info('Training...')
    # ===> parameters used when printing loss information
    all_step = len(train_loader)
    logger.info(f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader')
    logger.info(f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader')
    if len(global_state) > 0:
        best_model = global_state['best_model']
        start_epoch = global_state['start_epoch']
        global_step = global_state['global_step']
    else:
        best_model = {'best_acc': 0, 'eval_loss': 0, 'model_path': '', 'eval_acc': 0., 'eval_ned': 0.}
        start_epoch = 0
        global_step = 0

    checkpoint_dir = train_options['checkpoint_save_dir']
    latest_path = f"{checkpoint_dir}/latest.pth"
    best_path = f"{checkpoint_dir}/best.pth"

    def _snapshot_state(cur_epoch):
        """Write the current resumable position into ``global_state`` (in place)."""
        global_state['start_epoch'] = cur_epoch
        global_state['best_model'] = best_model
        global_state['global_step'] = global_step

    # Bound before the loop so the interrupt handler can always reference it.
    epoch = start_epoch
    # Start training
    try:
        for epoch in range(start_epoch, train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            start = time.time()
            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
                current_lr = optimizer.param_groups[0]['lr']
                cur_batch_size = batch_data['img'].shape[0]
                targets, targets_lengths = converter.encode(batch_data['label'])
                batch_data['targets'] = targets
                batch_data['targets_lengths'] = targets_lengths
                # Zero gradients, forward, backward, clip, step.
                optimizer.zero_grad()
                output = net.forward(batch_data['img'].to(to_use_device))
                loss_dict = loss_func(output, batch_data)
                loss_dict['loss'].backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
                optimizer.step()
                # statistic loss for print
                acc_dict = metric(output, batch_data['label'])
                acc = acc_dict['n_correct'] / cur_batch_size
                norm_edit_dis = 1 - acc_dict['norm_edit_dis'] / cur_batch_size
                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr} - "
                                f"loss:{loss_dict['loss'].item():.4f} - "
                                f"acc:{acc:.4f} - "
                                f"norm_edit_dis:{norm_edit_dis:.4f} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                # (i + 1) >= 1, so divisibility alone implies (i + 1) >= val_interval.
                if (i + 1) % train_options['val_interval'] == 0:
                    _snapshot_state(epoch)
                    save_checkpoint(latest_path, net, optimizer, logger, cfg, global_state=global_state)
                    if train_options['ckpt_save_type'] == 'HighestAcc':
                        # val
                        eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, converter, metric)
                        if eval_dict['eval_acc'] > best_model['eval_acc']:
                            best_model.update(eval_dict)
                            best_model['best_model_epoch'] = epoch
                            # Record the file that actually holds these weights;
                            # latest.pth is overwritten at the next interval.
                            best_model['models'] = best_path
                            _snapshot_state(epoch)
                            save_checkpoint(best_path, net, optimizer, logger, cfg, global_state=global_state)
                    elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0:
                        shutil.copy(latest_path, f"{checkpoint_dir}/{epoch}.pth")
                global_step += 1
            scheduler.step()
    except KeyboardInterrupt:
        import os
        # Refresh the state first so resuming from final.pth starts where
        # training actually stopped, not at the last validation snapshot.
        _snapshot_state(epoch)
        save_checkpoint(os.path.join(checkpoint_dir, 'final.pth'), net, optimizer, logger, cfg, global_state=global_state)
    except Exception:
        # Log-and-continue to the summary below; SystemExit and similar
        # BaseExceptions are no longer swallowed.
        logger.error(traceback.format_exc())
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')
     lambda x: _resize_img_for_recognize(x, 32),
     transforms.Normalize(std=[1, 1, 1], mean=[0.5, 0.5, 0.5]),
     transforms.ToTensor(),
 ])
 detector_model_type = ''
 detector_config = AttrDict()
 recognizer_model_type = ''
 recognizer_config = AttrDict()
 detector_pretrained_model_file = ''
 recognizer_pretrained_model_file = ''
 annotate_on_image = True
 need_rectify_on_single_character = True
 labels = ''.join([f'{i}'
                   for i in range(10)] + [chr(97 + i) for i in range(26)])
 # 模型推断
 label_converter = CTCLabelConverter(labels)
 device = torch.device(device_name)
 detector = DetModel(detector_config).to(device)
 detector.load_state_dict(
     torch.load(detector_pretrained_model_file, map_location='cpu'))
 recognizer = RecModel(recognizer_config).to(device)
 recognizer.load_state_dict(
     torch.load(recognizer_pretrained_model_file, map_location='cpu'))
 detector.eval()
 recognizer.eval()
 with torch.no_grad():
     for m_path, m_pil_img, m_eval_tensor in tqdm(
             get_data(eval_dataset_directory, eval_file,
                      eval_detect_transformer)):
         m_eval_tensor = m_eval_tensor.to(device)
         # 获得检测需要的相关信息
Ejemplo n.º 8
0
def train(net, optimizer, scheduler, loss_func, train_loader, eval_loader, to_use_device,
          cfg, _epoch, logger):
    """
    Training loop for the CTC text-recognition model.

    :param net: the model
    :param optimizer: optimizer
    :param scheduler: learning-rate scheduler, stepped once per epoch
    :param loss_func: loss function, called as ``loss_func(output, batch_data)``;
        its result must contain a ``'loss'`` tensor
    :param train_loader: training dataloader
    :param eval_loader: validation dataloader
    :param to_use_device: device the input images are moved to
    :param cfg: training configuration. NOTE: ``cfg.dataset.alphabet`` is read
        as a file path and then replaced in place with the file's contents, so
        calling this twice with the same ``cfg`` will fail on the second call.
    :param _epoch: epoch to start (or resume) from
    :param logger: logger object
    :return: None
    """
    from torchocr.metrics import RecMetric
    from torchocr.utils import CTCLabelConverter
    # Replace the alphabet file path with the alphabet itself (one entry per line).
    with open(cfg.dataset.alphabet, 'r', encoding='utf-8') as file:
        cfg.dataset.alphabet = ''.join([s.strip('\n') for s in file.readlines()])
    converter = CTCLabelConverter(cfg.dataset.alphabet)
    train_options = cfg.train_options
    metric = RecMetric(converter)
    # ===>
    logger.info('Training...')
    # ===> parameters used when printing loss information
    all_step = len(train_loader)
    logger.info(f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader')
    logger.info(f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader')
    best_model = {'best_acc': 0, 'eval_loss': 0, 'model_path': '', 'eval_acc': 0., 'eval_ned': 0.}
    checkpoint_dir = train_options['checkpoint_save_dir']
    # Bound before the loop so the interrupt handler can always reference it.
    epoch = _epoch
    # Start training
    try:
        start = time.time()
        for epoch in range(_epoch, train_options['epochs']):  # traverse each epoch
            net.train()  # train mode
            for i, batch_data in enumerate(train_loader):  # traverse each batch in the epoch
                current_lr = optimizer.param_groups[0]['lr']
                cur_batch_size = batch_data['img'].shape[0]
                targets, targets_lengths = converter.encode(batch_data['label'])
                batch_data['targets'] = targets
                batch_data['targets_lengths'] = targets_lengths

                optimizer.zero_grad()
                output = net.forward(batch_data['img'].to(to_use_device))
                loss_dict = loss_func(output, batch_data)
                loss_dict['loss'].backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 5)
                optimizer.step()
                # statistic loss for print
                acc_dict = metric(output, batch_data['label'])
                acc = acc_dict['n_correct'] / cur_batch_size
                norm_edit_dis = 1 - acc_dict['norm_edit_dis'] / cur_batch_size
                if (i + 1) % train_options['print_interval'] == 0:
                    interval_batch_time = time.time() - start
                    logger.info(f"[{epoch}/{train_options['epochs']}] - "
                                f"[{i + 1}/{all_step}] - "
                                f"lr:{current_lr} - "
                                f"loss:{loss_dict['loss'].item():.4f} - "
                                f"acc:{acc:.4f} - "
                                f"norm_edit_dis:{norm_edit_dis:.4f} - "
                                f"time:{interval_batch_time:.4f}")
                    start = time.time()
                # (i + 1) >= 1, so divisibility alone implies (i + 1) >= val_interval.
                if (i + 1) % train_options['val_interval'] == 0:
                    # val
                    eval_dict = evaluate(net, eval_loader, loss_func, to_use_device, logger, converter, metric)
                    if train_options['ckpt_save_type'] == 'HighestAcc':
                        net_save_path = f"{checkpoint_dir}/latest.pth"
                        save_checkpoint(net_save_path, net, optimizer, epoch, logger, cfg)
                        if eval_dict['eval_acc'] > best_model['eval_acc']:
                            best_path = f"{checkpoint_dir}/best.pth"
                            best_model.update(eval_dict)
                            # Record the file that keeps these weights; latest.pth
                            # is overwritten at the next interval.
                            best_model['models'] = best_path
                            shutil.copy(net_save_path, best_path)
                    elif train_options['ckpt_save_type'] == 'FixedEpochStep' and epoch % train_options['ckpt_save_epoch'] == 0:
                        net_save_path = f"{checkpoint_dir}/{epoch}.pth"
                        save_checkpoint(net_save_path, net, optimizer, epoch, logger, cfg)
            scheduler.step()
    except KeyboardInterrupt:
        import os
        save_checkpoint(os.path.join(checkpoint_dir, f'final_{epoch}.pth'), net,
                        optimizer, epoch, logger, cfg)
    except Exception:
        # Log the full traceback (limit=1 hid the frame that actually failed);
        # SystemExit and similar BaseExceptions are no longer swallowed.
        logger.error(traceback.format_exc())
    finally:
        for k, v in best_model.items():
            logger.info(f'{k}: {v}')
Ejemplo n.º 9
0
    :param net: 网络
    :param optimizer: 优化器
    :param scheduler: 学习率更新器
    :param loss_func: loss函数
    :param train_loader: 训练数据集 dataloader
    :param eval_loader: 验证数据集 dataloader
    :param to_use_device: device
    :param cfg: 当前训练所使用的配置
    :param global_state: 训练过程中的一些全局状态,如cur_epoch,cur_iter,最优模型的相关信息
    :param logger: logger 对象
    :return: None
    """

    from torchocr.metrics import RecMetric
    from torchocr.utils import CTCLabelConverter
    converter = CTCLabelConverter(cfg.dataset.alphabet)
    train_options = cfg.train_options
    metric = RecMetric(converter)
    # ===>
    logger.info('Training...')
    # ===> print loss信息的参数
    all_step = len(train_loader)
    logger.info(f'train dataset has {train_loader.dataset.__len__()} samples,{all_step} in dataloader')
    logger.info(f'eval dataset has {eval_loader.dataset.__len__()} samples,{len(eval_loader)} in dataloader')
    if len(global_state) > 0:
        best_model = global_state['best_model']
        start_epoch = global_state['start_epoch']
        global_step = global_state['global_step']
    else:
        best_model = {'best_acc': 0, 'eval_loss': 0, 'model_path': '', 'eval_acc': 0., 'eval_ned': 0.}
        start_epoch = 0