예제 #1
0
    def batch_generator(self, queue, max_queued=20):
        """Build synthetic training batches in an endless loop and enqueue them.

        Each iteration draws ``self.batch_size`` ``(img, lbl)`` samples from a
        ``GeneratorFromDict`` created for ``self.language``. It converts them
        and, if there is room, puts ``(batch_y, batch_dt, batch_x)`` on
        ``queue``. The loop has no exit, so this method never returns.

        Args:
            queue: Queue-like object; only ``qsize()`` and ``put()`` are used.
            max_queued: If ``queue.qsize()`` is already at or above this value,
                the freshly built batch is discarded instead of enqueued.
                Defaults to 20, the previously hard-coded limit.

        Enqueued tuple:
            batch_y: 1-D array of the raw label strings.
            batch_dt: Result of ``sparse_tuple_from`` applied to a 1-D object
                array holding one label index sequence per sample.
            batch_x: Images reshaped to
                ``(batch_size, self.max_image_width, 32, 1)``.
        """
        generator = GeneratorFromDict(language=self.language)
        while True:
            batch = []
            while len(batch) < self.batch_size:
                img, lbl = generator.next()
                batch.append((
                    # Convert to mode "L" and resize to the fixed width; only
                    # element [0] of resize_image's result is kept (presumably
                    # the resized image -- confirm against resize_image).
                    resize_image(np.array(img.convert("L")),
                                 self.max_image_width)[0],
                    lbl,
                    label_to_array(lbl, self.char_vector),
                ))

            raw_batch_x, raw_batch_y, raw_batch_la = zip(*batch)

            batch_y = np.reshape(np.array(raw_batch_y), (-1))

            # Keep exactly one label sequence per sample. np.array() on these
            # sequences raises for ragged lengths on NumPy >= 1.24, and for
            # equal lengths it yields a 2-D int array that reshape(-1) would
            # flatten into scalars, losing the per-sample grouping.
            batch_la = np.empty(len(raw_batch_la), dtype=object)
            for i, seq in enumerate(raw_batch_la):
                batch_la[i] = seq
            batch_dt = sparse_tuple_from(batch_la)

            # Swap axes 1 and 2 (presumably height <-> width) so the reshape
            # below yields (batch, max_image_width, 32, 1).
            raw_batch_x = np.swapaxes(raw_batch_x, 1, 2)

            batch_x = np.reshape(
                raw_batch_x,
                (len(raw_batch_x), self.max_image_width, 32, 1))

            # Best-effort cap: when the consumer is behind, this batch is
            # dropped rather than blocking the producer.
            if queue.qsize() < max_queued:
                queue.put((batch_y, batch_dt, batch_x))
class TextRecognition(Dataset):
    """Dataset that returns a freshly generated text-line sample on each access.

    Samples come from ``GeneratorFromDict`` (presumably trdg's), which is
    configured with the keyword arguments stored in ``self.args``. The
    ``index`` given to ``__getitem__`` is ignored: every access simply pulls
    the next sample from the generator.

    Args:
        count: Passed to the generator as ``count`` and reported by ``len()``.
        textlength: Passed to the generator as ``length``.
        dictpath: Passed to the generator as ``language``.
        fonts_dir: Directory whose entries are all passed as ``fonts``.
            Defaults to ``DEFAULT_FONTS_DIR`` (the original hard-coded path).
        img_dir: Passed to the generator as ``image_dir``. Defaults to
            ``DEFAULT_IMG_DIR`` (the original hard-coded path).
    """

    # Original machine-specific locations, kept as defaults for compatibility.
    DEFAULT_FONTS_DIR = "/home/ldl/桌面/python-notebook/My_trdg/trdg/fonts/cn"
    DEFAULT_IMG_DIR = "/home/ldl/桌面/论文/文本识别/TextRecognitionDataGenerator/trdg/images"

    def __init__(self, count, textlength, dictpath, fonts_dir=None, img_dir=None):
        self.count = count
        if fonts_dir is None:
            fonts_dir = self.DEFAULT_FONTS_DIR
        if img_dir is None:
            img_dir = self.DEFAULT_IMG_DIR
        # Sorted so the font list does not depend on filesystem listing order.
        fonts = [os.path.join(fonts_dir, name)
                 for name in sorted(os.listdir(fonts_dir))]
        self.args = dict(count=self.count,
                         length=textlength,
                         allow_variable=True,
                         fonts=fonts,
                         language=dictpath,
                         size=64,
                         blur=2,
                         random_blur=True,
                         image_dir=img_dir,
                         background_type=[0, 1, 2, 3],
                         distorsion_type=[0, 1, 2],
                         text_color="#000000,#FF8F8F",
                         image_mode="L",
                         char_cat="",
                         space_width=[1, 2, 3, 4],
                         character_spacing=[0, 1, 2, 3, 4, 5])
        self.generator = GeneratorFromDict(**self.args)

    def __getitem__(self, index):
        """Return the next ``(img, label)`` pair; ``index`` is ignored.

        If the current generator raises ``StopIteration``, a new one is built
        from ``self.args`` and a sample is drawn from it instead.
        """
        try:
            img, label = self.generator.next()
        except StopIteration:
            self.generator = GeneratorFromDict(**self.args)
            img, label = self.generator.next()

        return img, label

    def __len__(self):
        """Return the configured ``count``."""
        return self.count