コード例 #1
0
    def __init__(self, config):
        """
        Load the alphabet and the ground-truth list selected by ``config.mode``.

        :param config: dataset config. This constructor reads ``data_dir``,
            ``input_h``, ``mode`` (selects ``<mode>.txt`` inside ``data_dir``,
            e.g. 'train' or 'val'), ``alphabet`` (path of the alphabet file),
            ``mean``, ``std`` and ``augmentation``. The whole object is also
            kept as ``self.config``.
        :raises ValueError: if a non-blank ground-truth line does not consist
            of exactly two tab-separated fields.
        """
        self.config = config
        self.data_dir = config.data_dir
        self.input_h = config.input_h
        self.mode = config.mode
        self.alphabet = config.alphabet
        self.mean = np.array(config.mean, dtype=np.float32)
        self.std = np.array(config.std, dtype=np.float32)
        self.augmentation = config.augmentation

        # get alphabet: concatenate the characters of every line.
        # Decode explicitly as UTF-8, like the ground-truth file below, so the
        # characters do not depend on the platform's default encoding.
        with open(self.alphabet, 'r', encoding='utf-8') as file:
            alphabet = ''.join(m_line.strip('\n') for m_line in file)
        # get converter
        self.converter = StrLabelConverter(alphabet, False)

        # build path of train.txt or val.txt
        gt_path = os.path.join(self.data_dir, f'{self.mode}.txt')
        with open(gt_path, 'r', encoding='utf-8') as file:
            # build [(image_name, gt_text), ...]
            self.labels = []
            for m_line in file:
                m_line = m_line.strip()
                if not m_line:
                    # tolerate blank lines (e.g. a trailing empty line)
                    continue
                m_image_name, m_gt_text = m_line.split('\t')
                self.labels.append((m_image_name, m_gt_text))

        print(f'load {len(self)} images.')
コード例 #2
0
class icdar15RecDataset(Dataset):
    """Recognition dataset driven by a tab-separated ground-truth file.

    ``self.labels`` is a list of single-entry dicts ``{img_name: trans}``;
    images are read from ``<data_dir>/images/<img_name>``.
    """

    def __init__(self, config):
        """
        Load the alphabet and the ground-truth list selected by ``config.mode``.

        :param config: dataset config. This constructor reads ``data_dir``,
            ``input_h``, ``mode`` (selects ``<mode>.txt`` inside ``data_dir``,
            e.g. 'train' or 'val'), ``alphabet`` (path of the alphabet file),
            ``mean``, ``std`` and ``augmentation``. The whole object is kept
            as ``self.config`` and handed to ``RecDataProcess`` when
            augmentation is enabled.
        """
        self.config = config
        self.data_dir = config.data_dir
        self.input_h = config.input_h
        self.mode = config.mode
        self.alphabet = config.alphabet
        self.mean = np.array(config.mean, dtype=np.float32)
        self.std = np.array(config.std, dtype=np.float32)
        self.augmentation = config.augmentation

        # get alphabet: concatenate the characters of every line.
        # Decode explicitly as UTF-8, like the ground-truth file below, so the
        # characters do not depend on the platform's default encoding.
        with open(self.alphabet, 'r', encoding='utf-8') as file:
            alphabet = ''.join(line.strip('\n') for line in file)
        # get converter
        self.converter = StrLabelConverter(alphabet, False)

        # build path of train.txt or val.txt
        gt_path = os.path.join(self.data_dir, f'{self.mode}.txt')
        with open(gt_path, 'r', encoding='utf-8') as file:
            # build [{img_name: trans}, ...], ignoring blank lines
            parsed = (self._parse_label_line(line) for line in file)
            self.labels = [entry for entry in parsed if entry is not None]

        print(f'load {len(self)} images.')

    @staticmethod
    def _parse_label_line(line):
        """Turn one ground-truth line into ``{img_name: trans}``.

        The image name is the first tab-separated field and the transcription
        is the last one. Only the line terminator is removed, so the last
        line of a file keeps its final character even without a trailing
        newline.

        :param line: one raw line of the ground-truth file.
        :return: a single-entry dict, or ``None`` for a blank line.
        """
        line = line.rstrip('\n')
        if not line.strip():
            return None
        fields = line.split('\t')
        return {fields[0]: fields[-1]}

    def _findmaxlength(self):
        """Return the length of the longest transcription (0 if no labels)."""
        return max(
            (len(next(iter(d.values()))) for d in self.labels),
            default=0,
        )

    def __len__(self):
        """Return the number of loaded ground-truth entries."""
        return len(self.labels)

    def __getitem__(self, index):
        """
        Load the sample at ``index``.

        :param index: position in ``self.labels``.
        :return: ``(img, label, length)`` where ``img`` is the image after
            ``cv2.COLOR_BGR2GRAY`` conversion and ``label, length`` are
            whatever ``self.converter.encode`` returns for the transcription.
        :raises OSError: if the image cannot be read.
        """
        # every entry is a single-item dict {img_name: trans}
        img_name, trans = next(iter(self.labels[index].items()))
        img_path = os.path.join(self.data_dir, f'images/{img_name}')

        # convert to label
        label, length = self.converter.encode(trans)
        # read img; cv2.imread signals failure by returning None, not raising
        img = cv2.imread(img_path)
        if img is None:
            raise OSError(f'failed to read image: {img_path}')
        # do aug
        if self.augmentation:
            img = pil2cv(RecDataProcess(self.config).aug_img(cv2pil(img)))
        # to gray
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        return img, label, length
コード例 #3
0
class ICDAR15RecDataset(Dataset):
    """Recognition dataset driven by a tab-separated ground-truth file.

    ``self.labels`` is a list of ``(image_name, gt_text)`` tuples; images are
    read from ``<data_dir>/images/<image_name>``.
    """

    def __init__(self, config):
        """
        Load the alphabet and the ground-truth list selected by ``config.mode``.

        :param config: dataset config. This constructor reads ``data_dir``,
            ``input_h``, ``mode`` (selects ``<mode>.txt`` inside ``data_dir``,
            e.g. 'train' or 'val'), ``alphabet`` (path of the alphabet file),
            ``mean``, ``std`` and ``augmentation``. The whole object is kept
            as ``self.config`` and handed to ``RecDataProcess`` when
            augmentation is enabled.
        :raises ValueError: if a non-blank ground-truth line does not consist
            of exactly two tab-separated fields.
        """
        self.config = config
        self.data_dir = config.data_dir
        self.input_h = config.input_h
        self.mode = config.mode
        self.alphabet = config.alphabet
        self.mean = np.array(config.mean, dtype=np.float32)
        self.std = np.array(config.std, dtype=np.float32)
        self.augmentation = config.augmentation

        # get alphabet: concatenate the characters of every line.
        # Decode explicitly as UTF-8, like the ground-truth file below, so the
        # characters do not depend on the platform's default encoding.
        with open(self.alphabet, 'r', encoding='utf-8') as file:
            alphabet = ''.join(m_line.strip('\n') for m_line in file)
        # get converter
        self.converter = StrLabelConverter(alphabet, False)

        # build path of train.txt or val.txt
        gt_path = os.path.join(self.data_dir, f'{self.mode}.txt')
        with open(gt_path, 'r', encoding='utf-8') as file:
            # build [(image_name, gt_text), ...], ignoring blank lines
            parsed = (self._parse_label_line(m_line) for m_line in file)
            self.labels = [m_pair for m_pair in parsed if m_pair is not None]

        print(f'load {len(self)} images.')

    @staticmethod
    def _parse_label_line(line):
        """Split one ground-truth line into ``(image_name, gt_text)``.

        :param line: one raw line of the ground-truth file.
        :return: the two tab-separated fields with surrounding whitespace of
            the line stripped, or ``None`` for a blank line.
        :raises ValueError: if the line does not hold exactly two fields.
        """
        line = line.strip()
        if not line:
            return None
        image_name, gt_text = line.split('\t')
        return image_name, gt_text

    def _find_max_length(self):
        """Return the length of the longest transcription (0 if no labels)."""
        # measure each text with len(); comparing the strings themselves
        # would yield the lexicographically largest text, not a length
        return max((len(m_text) for _, m_text in self.labels), default=0)

    def __len__(self):
        """Return the number of loaded ground-truth entries."""
        return len(self.labels)

    def __getitem__(self, index):
        """
        Load the sample at ``index``.

        :param index: position in ``self.labels``.
        :return: ``(img, label, length)`` where ``img`` is the image after
            ``cv2.COLOR_BGR2GRAY`` conversion and ``label, length`` are
            whatever ``self.converter.encode`` returns for the transcription.
        :raises OSError: if the image cannot be read.
        """
        # get img_path and trans
        img_name, trans = self.labels[index]
        img_path = os.path.join(self.data_dir, 'images', img_name)

        # convert to label
        label, length = self.converter.encode(trans)
        # read img; cv2.imread signals failure by returning None, not raising
        img = cv2.imread(img_path)
        if img is None:
            raise OSError(f'failed to read image: {img_path}')
        # do aug
        if self.augmentation:
            img = pil2cv(RecDataProcess(self.config).aug_img(cv2pil(img)))
        # to gray
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        return img, label, length
コード例 #4
0
                torch.save(model, weights_dir + 'weights_'+ file_name + '_lr_' + str(lr) + '_num_classes_' + str(nclass) + \
                        '_batch_size_' + str(batch_size) + '.pt')

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best test Acc: {:4f}'.format(best_acc))

    model = torch.load(weights_dir + 'weights_'+ file_name + '_lr_' + str(lr) + '_num_classes_' + str(nclass) + \
            '_batch_size_' + str(batch_size) + '.pt')

    return model


# Label converter, datasets, loaders and run name.
# The converter alphabet is every key of class_map followed by one space.
converter = StrLabelConverter('{} '.format(''.join(class_map.keys())))

# One preloader per split, each built from its CSV under data_dir.
image_datasets = {
    split: EnglishImagePreloader(data_dir + csv_name)
    for split, csv_name in (('train', train_csv), ('test', test_csv))
}

# Both splits use the same loader settings (shuffled, 4 workers).
dataloaders = {
    split: torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
    )
    for split, dataset in image_datasets.items()
}
dataset_sizes = {
    split: len(dataset) for split, dataset in image_datasets.items()
}

# Script name: last '/'-separated path component, cut at its first '.'.
file_name = __file__.rsplit('/', 1)[-1].partition('.')[0]

#Create model and initialize/freeze weights
コード例 #5
0
 eval_recognize_transformer = transforms.Compose([
     transforms.ToPILImage(),
     lambda x: _resize_img_for_recognize(x, 32),
     transforms.Normalize(std=[1, 1, 1], mean=[0.5, 0.5, 0.5]),
     transforms.ToTensor(),
 ])
 detector_model_type = 'db'
 recognizer_model_type = 'crnn_res'
 detector_pretrained_model_file = ''
 recognizer_pretrained_model_file = ''
 annotate_on_image = True
 need_rectify_on_single_character = True
 labels = ''.join([f'{i}'
                   for i in range(10)] + [chr(97 + i) for i in range(26)])
 # 模型推断
 label_converter = StrLabelConverter(labels)
 device = torch.device(device_name)
 detector = get_detector(detector_model_type,
                         detector_model_type).to(device)
 recognizer = get_recognizer(recognizer_model_type,
                             recognizer_pretrained_model_file).to(device)
 detector.eval()
 recognizer.eval()
 with torch.no_grad():
     for m_path, m_pil_img, m_eval_tensor in tqdm(
             get_data(eval_dataset_directory, eval_file,
                      eval_detect_transformer)):
         m_eval_tensor = m_eval_tensor.to(device)
         # 获得检测需要的相关信息
         m_detect_result = detector(m_eval_tensor)
         # 根据网络类型,处理检测的相关信息,最后转换为一堆多边形