Example #1
0
class Data(RNGDataFlow):
    """Dataflow yielding ``[feat, label]`` pairs for text-line images.

    Image paths come from ``cfg.train_list`` or ``cfg.test_list`` (either a
    single entry or a list of entries), each expanded through
    ``get_imglist``. Every image is paired with a label file that has the same
    path stem and a ``.txt`` extension; only the first line of that file is
    used.
    """

    def __init__(self, train_or_test, shuffle=True):
        """Collect the image list for the requested split.

        Args:
            train_or_test: ``'train'`` or ``'test'``; selects whether
                ``cfg.train_list`` or ``cfg.test_list`` is read.
            shuffle: if True, ``get_data`` visits images in an order shuffled
                with ``self.rng``.
        """
        assert train_or_test in ['train', 'test']
        self.train_or_test = train_or_test
        fname_list = cfg.train_list if train_or_test == "train" else cfg.test_list
        # Accept either a single list entry or a list of them.
        if not isinstance(fname_list, list):
            fname_list = [fname_list]

        self.imglist = []
        for fname in fname_list:
            self.imglist.extend(get_imglist(fname))

        self.shuffle = shuffle

        self.mapper = Mapper()

    def size(self):
        """Return the number of listed images.

        This is an upper bound on the samples ``get_data`` yields, because
        images whose cleaned label is empty are skipped.
        """
        return len(self.imglist)

    def get_data(self):
        """Yield ``[feat, label]`` for every image with a non-empty cleaned label.

        ``feat`` is the image read with ``misc.imread``. If its row count
        differs from ``cfg.input_height``, it is resized to
        ``(cfg.input_width, cfg.input_height)`` when ``cfg.input_width`` is
        set, or scaled proportionally otherwise. A trailing axis is then added
        with ``np.expand_dims(..., axis=2)``.

        ``label`` is ``self.mapper.encode_string`` applied to the first line of
        the label file, after dropping every character not in
        ``cfg.dictionary``. Samples whose cleaned label is empty, including
        those with an empty label file, are skipped.
        """
        import os.path

        # Membership set built once per pass instead of scanning
        # cfg.dictionary for every character of every label.
        allowed_chars = set(cfg.dictionary)

        idxs = np.arange(len(self.imglist))
        if self.shuffle:
            self.rng.shuffle(idxs)
        for k in idxs:
            img_path = self.imglist[k]
            # splitext strips only the final extension; split('.')[0] would cut
            # at the first dot anywhere in the path (e.g. './data/a.png').
            label_path = os.path.splitext(img_path)[0] + ".txt"

            # Read the label before the image so skipped samples are never
            # decoded or resized.
            with open(label_path) as f:
                # readline() returns '' for an empty file, whereas
                # readlines()[0] would raise IndexError.
                first_line = f.readline()
            label_cleaned = ''.join(ch for ch in first_line if ch in allowed_chars)
            if label_cleaned == "":
                continue

            # NOTE(review): in scipy.misc.imread the second positional parameter
            # is `flatten`, so 'L' may be taken as flatten=True (float output)
            # rather than mode='L' -- confirm against the installed SciPy.
            img = misc.imread(img_path, 'L')
            if img.shape[0] != cfg.input_height:
                if cfg.input_width is not None:
                    img = cv2.resize(img, (cfg.input_width, cfg.input_height))
                else:
                    # float() keeps the ratio fractional even under Python 2
                    # integer division.
                    scale = float(cfg.input_height) / img.shape[0]
                    img = cv2.resize(img, None, fx=scale, fy=scale)
            feat = np.expand_dims(img, axis=2)

            label = self.mapper.encode_string(label_cleaned)
            yield [feat, label]
Example #2
0
class TextDF(DataFlow):
    """Dataflow yielding ``[feat, label]`` for every ``.png`` below a directory.

    Each ``<stem>.png`` is paired with a label file ``<stem>.txt``. The first
    line of that file is encoded with a ``Mapper`` built from ``dict_path``.
    """

    def __init__(self, dirname, dict_path, channel=1):
        """Index all ``.png`` files found below ``dirname``.

        Args:
            dirname: root directory walked with ``fs.recursive_walk``.
            dict_path: passed to ``Mapper`` to build the label encoder.
            channel: stored as ``self.channel``; not read elsewhere in this
                class.
        """
        self.dirname = dirname
        self.channel = channel
        self.filelists = [
            k for k in fs.recursive_walk(self.dirname) if k.endswith('.png')
        ]
        logger.info("Found {} png files ...".format(len(self.filelists)))

        self.mapper = Mapper(dict_path)

    def size(self):
        """Return the number of indexed ``.png`` files."""
        return len(self.filelists)

    def get_data(self):
        """Yield ``[feat, label]`` for each indexed image, in index order.

        ``feat`` is the image read with ``misc.imread`` with a trailing axis
        added by ``np.expand_dims(..., axis=2)``. ``label`` is
        ``self.mapper.encode_string`` applied to the first line of the matching
        ``.txt`` file. An empty label file raises ``IndexError``.
        """
        import os.path

        for filename in self.filelists:
            # NOTE(review): in scipy.misc.imread the second positional parameter
            # is `flatten`, so 'L' may be taken as flatten=True (float output)
            # rather than mode='L' -- confirm against the installed SciPy.
            feat = misc.imread(filename, 'L')
            feat = np.expand_dims(feat, axis=2)
            # Swap only the extension; str.replace("png", "txt") would also
            # rewrite "png" inside directory or file names.
            label_filename = os.path.splitext(filename)[0] + ".txt"
            with open(label_filename) as label_file:
                content = label_file.readlines()
            yield [feat, self.mapper.encode_string(content[0])]