예제 #1
0
def get_key(label_file_list, ignore_chinese_punctuation, show_max_img=False):
    """Collect the distinct characters used by the labels in the given files.

    Every line of a label file is split into an image path and a label on a
    tab; a single space directly after ``.jpg`` or ``.png`` is accepted as
    the separator as well. Lines that yield no label are skipped.

    Args:
        label_file_list: Iterable of paths to UTF-8 encoded label files.
        ignore_chinese_punctuation: When truthy, each label is passed through
            ``punctuation_mend`` before its characters are collected.
        show_max_img: When truthy, each image is read with ``cv2.imread`` and
            the longest label length plus the largest image height and width
            are printed once all files have been processed.

    Returns:
        A string containing every distinct label character, sorted.
    """
    chars = set()
    max_len = 0
    max_h = 0
    max_w = 0
    unreadable = 0
    for label_path in label_file_list:
        with open(label_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f.readlines(), desc=label_path):
                fields = line.strip('\n').replace('.jpg ', '.jpg\t').replace(
                    '.png ', '.png\t').split('\t')
                if len(fields) < 2:
                    continue
                label = fields[1]
                if ignore_chinese_punctuation:
                    label = punctuation_mend(label)
                chars.update(label)
                # The length statistic is taken on the raw label, before
                # any punctuation normalisation.
                max_len = max(max_len, len(fields[1]))
                if show_max_img:
                    img = cv2.imread(fields[0])
                    if img is None:
                        # cv2.imread returns None instead of raising when
                        # the file is missing or cannot be decoded.
                        unreadable += 1
                        continue
                    h, w = img.shape[:2]
                    max_h = max(max_h, h)
                    max_w = max(max_w, w)
    if show_max_img:
        print(
            'max len of label is {}, max img_h is {}, max img_w is {}'.format(
                max_len, max_h, max_w))
        if unreadable:
            print('{} images could not be read and were skipped'.format(
                unreadable))
    return ''.join(sorted(chars))
예제 #2
0
def get_key(label_file_list, ignore_chinese_punctuation, show_max_img=False):
    """Collect the distinct characters used by the labels in the given files.

    Every line of a label file is split into an image path and a label on a
    tab; a single space directly after ``.jpg`` or ``.png`` is accepted as
    the separator as well. Lines that yield no label, or whose image path
    does not exist on disk, are skipped.

    Args:
        label_file_list: Iterable of paths to UTF-8 encoded label files.
        ignore_chinese_punctuation: When truthy, each label is passed through
            ``punctuation_mend`` before its characters are collected.
        show_max_img: When truthy, each image is read with ``cv2.imread`` and
            histograms of image width, image height and label length are
            handed to ``show_dict`` once all files have been processed.

    Returns:
        A string containing every distinct label character, sorted.
    """
    chars = set()
    len_dict = defaultdict(int)
    h_dict = defaultdict(int)
    w_dict = defaultdict(int)
    unreadable = 0
    for label_path in label_file_list:
        with open(label_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f.readlines(), desc=label_path):
                fields = line.strip('\n').replace('.jpg ', '.jpg\t').replace(
                    '.png ', '.png\t').split('\t')
                if len(fields) < 2 or not os.path.exists(fields[0]):
                    continue
                label = fields[1]
                if ignore_chinese_punctuation:
                    label = punctuation_mend(label)
                chars.update(label)
                # The length histogram is built from the raw label, before
                # any punctuation normalisation.
                len_dict[len(fields[1])] += 1
                if show_max_img:
                    img = cv2.imread(fields[0])
                    if img is None:
                        # An existing file can still fail to decode;
                        # cv2.imread then returns None instead of raising.
                        unreadable += 1
                        continue
                    h, w = img.shape[:2]
                    h_dict[h] += 1
                    w_dict[w] += 1
    if show_max_img:
        print('******************分析宽度******************')
        show_dict(w_dict, 10, 'w')
        print('******************分析高度******************')
        show_dict(h_dict, 1, 'h')
        print('******************分析label长度******************')
        show_dict(len_dict, 1, 'label')
        if unreadable:
            print('{} images could not be read and were skipped'.format(
                unreadable))
    return ''.join(sorted(chars))
예제 #3
0
 def get_sample(self, index):
     """Load one sample: the decoded image together with its label.

     The image is read with OpenCV as grayscale when ``self.img_mode`` is
     ``'GRAY'`` and as a colour image otherwise; for ``'RGB'`` the channel
     order is converted from BGR to RGB. The label optionally goes through
     ``punctuation_mend`` and optionally has its spaces removed.

     The result of ``cv2.imread`` is not checked here, so an unreadable
     path yields whatever OpenCV does with a ``None`` image.

     Args:
         index: Position of the sample in ``self.data_list``.

     Returns:
         A dict with the keys ``'img'`` and ``'label'``.
     """
     path, text = self.data_list[index]
     # OpenCV read flag: 0 decodes as grayscale, 1 as a colour image.
     read_flag = 0 if self.img_mode == 'GRAY' else 1
     picture = cv2.imread(path, read_flag)
     if self.img_mode == 'RGB':
         picture = cv2.cvtColor(picture, cv2.COLOR_BGR2RGB)
     if self.ignore_chinese_punctuation:
         text = punctuation_mend(text)
     if self.remove_blank:
         text = text.replace(' ', '')
     sample = {'img': picture, 'label': text}
     return sample
예제 #4
0
 def __getitem__(self, idx):
     """Return the pre-processed image and encoded label for one sample.

     The image is read as 3-channel when ``self.img_channel`` is 3 and as
     single-channel otherwise. Spaces are removed from the label, which is
     optionally passed through ``punctuation_mend`` and then encoded with
     ``self.label_enocder``. In the ``'train'`` phase the image is
     augmented with ``seq`` before pre-processing.

     Args:
         idx: Position of the sample in ``self.data_list``.

     Returns:
         A tuple ``(img, label)``. If encoding the label fails, the error
         is logged and the un-encoded label is returned.
     """
     img_path, label = self.data_list[idx]
     img = image.imread(img_path,
                        1 if self.img_channel == 3 else 0).asnumpy()
     label = label.replace(' ', '')
     if self.ignore_chinese_punctuation:
         label = punctuation_mend(label)
     try:
         label = self.label_enocder(label)
     except Exception:
         # Best effort: keep going with the raw label, but do not swallow
         # KeyboardInterrupt / SystemExit the way a bare ``except`` would.
         logger.error('meet error when encode label, {},{}'.format(
             img_path, label))
     if self.phase == 'train':
         img = seq.augment_image(img)
     img = nd.array(img)
     img = self.pre_processing(img)
     return img, label
예제 #5
0
    def __getitem__(self, index):
        """Return the pre-processed image and encoded label for one sample.

        ``index`` is first mapped through ``self.filtered_index_list`` to
        the record number used in the LMDB keys ``label-%09d`` and
        ``image-%09d``. The image is decoded as RGB when
        ``self.img_channel`` is 3 and as 8-bit grayscale otherwise. When
        decoding raises ``IOError`` a blank image of size
        ``(self.img_w, self.img_h)`` and a dummy label are used instead.

        Characters that are not part of ``self.alphabet`` and spaces are
        removed from the label before it is optionally passed through
        ``punctuation_mend`` and encoded with ``self.label_enocder``.

        Args:
            index: Position of the sample within the filtered index list.

        Returns:
            A tuple ``(img, label)``.
        """
        # NOTE(review): this check lets index == len(self) through; such an
        # index is presumably rejected only by the list lookup below.
        assert index <= len(self), 'index range error'
        index = self.filtered_index_list[index]

        with self.env.begin(write=False) as txn:
            label_key = 'label-%09d'.encode() % index
            label = txn.get(label_key).decode('utf-8')
            img_key = 'image-%09d'.encode() % index
            imgbuf = txn.get(img_key)

            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            # PIL mode: 'RGB' for colour images, 'L' for 8-bit grayscale.
            pil_mode = 'RGB' if self.img_channel == 3 else 'L'
            try:
                img = Image.open(buf).convert(pil_mode)
            except IOError:
                print('Corrupted image for {}'.format(index))
                # make dummy image and dummy label for corrupted image.
                img = Image.new(pil_mode, (self.img_w, self.img_h))
                label = '嫑'

            # Keep only characters of the alphabet. The alphabet is escaped
            # so that characters such as ']', '\\', '^' or '-' are matched
            # literally instead of altering the character class.
            out_of_char = '[^{}]'.format(re.escape(str(self.alphabet)))
            label = re.sub(out_of_char, '', label)
            label = label.replace(' ', '')
            if self.ignore_chinese_punctuation:
                label = punctuation_mend(label)
            label = self.label_enocder(label)
            img = nd.array(np.array(img))
            img = self.pre_processing(img)
        return (img, label)
예제 #6
0
    def get_sample(self, index):
        """Read one image/label pair from the LMDB environment.

        ``self.data_list[index]`` gives the record number used in the LMDB
        keys ``label-%09d`` and ``image-%09d``. The image is decoded as RGB
        or as 8-bit grayscale depending on ``self.img_mode``. The label
        optionally has its spaces removed and is optionally passed through
        ``punctuation_mend``.

        Args:
            index: Position of the sample in ``self.data_list``.

        Returns:
            A tuple ``(img, label)`` where ``img`` is a NumPy array.

        Raises:
            NotImplementedError: If ``self.img_mode`` is neither ``'RGB'``
                nor ``'GRAY'``.
        """
        record_id = self.data_list[index]
        with self.env.begin(write=False) as txn:
            text = txn.get('label-%09d'.encode() % record_id).decode('utf-8')
            raw_image = txn.get('image-%09d'.encode() % record_id)

            stream = six.BytesIO()
            stream.write(raw_image)
            stream.seek(0)

            # Translate the dataset's mode name into the PIL mode string.
            if self.img_mode == 'RGB':
                pil_mode = 'RGB'
            elif self.img_mode == "GRAY":
                pil_mode = 'L'
            else:
                raise NotImplementedError
            picture = Image.open(stream).convert(pil_mode)

            if self.remove_blank:
                text = text.replace(' ', '')
            if self.ignore_chinese_punctuation:
                text = punctuation_mend(text)
            pixels = np.array(picture)
        return pixels, text
예제 #7
0
    def get_sample(self, index):
        """Read one image/label pair from the LMDB environment.

        ``self.data_list[index]`` gives the record number used in the LMDB
        keys ``label-%09d`` and ``image-%09d``. The image is decoded as RGB
        or as 8-bit grayscale depending on ``self.img_mode``. When decoding
        raises ``IOError`` a blank image of size
        ``(self.img_w, self.img_h)`` in the same mode and a dummy label are
        used instead.

        Characters that are not part of ``self.alphabet`` are removed from
        the label; it then optionally has its spaces removed and is
        optionally passed through ``punctuation_mend``.

        Args:
            index: Position of the sample in ``self.data_list``.

        Returns:
            A tuple ``(img, label)`` where ``img`` is a NumPy array.

        Raises:
            NotImplementedError: If ``self.img_mode`` is neither ``'RGB'``
                nor ``'GRAY'``.
        """
        index = self.data_list[index]
        with self.env.begin(write=False) as txn:
            label_key = 'label-%09d'.encode() % index
            label = txn.get(label_key).decode('utf-8')
            img_key = 'image-%09d'.encode() % index
            imgbuf = txn.get(img_key)

            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)

            # Resolve the PIL mode once so the real image and the fallback
            # image below are always created in the same mode.
            if self.img_mode == 'RGB':
                pil_mode = 'RGB'
            elif self.img_mode == "GRAY":
                pil_mode = 'L'
            else:
                raise NotImplementedError
            try:
                img = Image.open(buf).convert(pil_mode)
            except IOError:
                print('Corrupted image for {}'.format(index))
                # make dummy image and dummy label for corrupted image.
                img = Image.new(pil_mode, (self.img_w, self.img_h))
                label = '嫑'

            # Keep only characters of the alphabet. The alphabet is escaped
            # so that characters such as ']', '\\', '^' or '-' are matched
            # literally instead of altering the character class.
            out_of_char = '[^{}]'.format(re.escape(str(self.alphabet)))
            label = re.sub(out_of_char, '', label)
            if self.remove_blank:
                label = label.replace(' ', '')
            if self.ignore_chinese_punctuation:
                label = punctuation_mend(label)
            img = np.array(img)
        return img, label
예제 #8
0
        imageKey = 'image-%09d'.encode() % cnt
        labelKey = 'label-%09d'.encode() % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label.encode()

        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
        cnt += 1
    nSamples = cnt - 1
    cache['num-samples'.encode()] = str(nSamples).encode()
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)


if __name__ == '__main__':
    # Build the train and validation LMDB datasets from the CSV listed
    # below and store the alphabet of all label characters next to the
    # training set.
    data_list = [["/media/zj/资料/zj/dataset/train_linux.csv"]]
    save_path = './lmdb/train'
    os.makedirs(save_path, exist_ok=True)
    train_data_list, val_data_list = get_datalist(data_list,
                                                  val_data_path=data_list[0])
    train_data_list = train_data_list[0]
    # Element 1 of every entry is used as the label text.
    alphabet = [x[1] for x in train_data_list]
    alphabet.extend([x[1] for x in val_data_list])
    alphabet = [punctuation_mend(x) for x in alphabet]
    alphabet = ''.join(sorted(set((''.join(alphabet)))))
    # str.replace returns a new string, so the result has to be rebound for
    # the space to actually be dropped from the alphabet.
    alphabet = alphabet.replace(' ', '')
    np.save(os.path.join(save_path, 'alphabet.npy'), alphabet)
    createDataset(train_data_list, save_path)
    createDataset(val_data_list, save_path.replace('train', 'validation'))