def __init__(self, dirPath, split, config):

        self.img_height = config['img_height']

        #with open(os.path.join(dirPath,'sets.json')) as f:
        with open(os.path.join('data', 'sets.json')) as f:
            set_list = json.load(f)[split]

        self.authors = defaultdict(list)
        self.lineIndex = []
        for page_idx, name in enumerate(set_list):
            lines, author = parseXML(
                os.path.join(dirPath, 'xmls', name + '.xml'))

            authorLines = len(self.authors[author])
            self.authors[author] += [
                (os.path.join(dirPath, 'forms', name + '.png'), ) + l
                for l in lines
            ]
            self.lineIndex += [(author, i + authorLines)
                               for i in range(len(lines))]

        char_set_path = config['char_file']
        with open(char_set_path) as f:
            char_set = json.load(f)
        self.char_to_idx = char_set['char_to_idx']

        self.augmentation = config.get('augmentation')
        self.normalized_dir = config.get('cache_normalized')
        if self.normalized_dir is not None:
            ensure_dir(self.normalized_dir)

        self.warning = False

        #DEBUG
        if 'overfit' in config and config['overfit']:
            self.lineIndex = self.lineIndex[:10]

        self.center = config['center_pad']  #if 'center_pad' in config else True

        self.add_spaces = config.get('add_spaces', False)
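
Pieced together from the lookups above, a minimal config for this constructor might look like the sketch below; the concrete values and the char-file path are illustrative assumptions, not the project's defaults.

import json

# Hypothetical char file: the constructor only requires a 'char_to_idx' map.
with open('data/char_set.json', 'w') as f:
    json.dump({'char_to_idx': {' ': 1, 'a': 2, 'b': 3}}, f)

config = {
    'img_height': 64,                    # required
    'char_file': 'data/char_set.json',   # required
    'center_pad': True,                  # required in this version (no fallback)
    'augmentation': None,                # optional
    'add_spaces': False,                 # optional
    'overfit': False,                    # optional; truncates lineIndex to 10
}
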
    def __init__(self, dirPath, split, config):
        if 'split' in config:
            split = config['split']

        self.img_height = config['img_height']
        self.a_batch_size = config['a_batch_size']

        #with open(os.path.join(dirPath,'sets.json')) as f:
        with open(os.path.join('data', 'wordMixed_sets.json')) as f:
            set_list = json.load(f)[split]

        #self.authors = defaultdict(list)
        self.w_authors = defaultdict(list)
        allLines = defaultdict(dict)
        pageAuthors = {}  # remember each page's author; cached pages must not reuse a stale value
        self.lineIndex = []
        for word_idx, word_id in enumerate(set_list):
            line_id = word_id[:word_id.rfind('-')]
            page_id = line_id[:line_id.rfind('-')]
            if page_id not in allLines:
                w_lines, lines, author = parseXML(
                    os.path.join(dirPath, 'xmls', page_id + '.xml'))
                pageAuthors[page_id] = author
                for w_line in w_lines:
                    for bounds, trans, id in w_line:
                        allLines[page_id][id] = (bounds, trans)
            author = pageAuthors[page_id]
            bounds, trans = allLines[page_id][word_id]
            self.w_authors[author].append(
                (os.path.join(dirPath, 'forms',
                              page_id + '.png'), bounds, trans, word_id))

        #minLines=99999
        #for author,lines in self.authors.items():
        #print('{} {}'.format(author,len(lines)))
        #minLines = min(minLines,len(lines))
        #maxCombs = int(nCr(minLines,self.a_batch_size)*1.2)
        for author, words in self.w_authors.items():
            #if split=='train':
            #    combs=list(itertools.combinations(list(range(len(lines))),self.a_batch_size))
            #    np.random.shuffle(combs)
            #    self.lineIndex += [(author,c) for c in combs[:maxCombs]]
            #else:
            for i in range(len(words) // self.a_batch_size):
                ls = list(range(self.a_batch_size * i,
                                self.a_batch_size * (i + 1)))
                self.lineIndex.append((author, ls))
            leftover = len(words) % self.a_batch_size
            if leftover > 0:  # pad a final short batch; skip when words divide evenly
                fill = self.a_batch_size - leftover
                last = list(range(fill))
                last += [len(words) - (1 + i) for i in range(leftover)]
                self.lineIndex.append((author, last))

            #if split=='train':
            #    ss = set(self.lineIndex)
        self.authors = self.w_authors.keys()

        char_set_path = config['char_file']
        with open(char_set_path) as f:
            char_set = json.load(f)
        self.char_to_idx = char_set['char_to_idx']
        self.augmentation = config.get('augmentation')
        self.warning = False

        if 'style_loc' in config:
            by_author_styles = defaultdict(list)
            by_author_all_ids = defaultdict(set)
            style_loc = config['style_loc']
            if style_loc[-1] != '*':
                style_loc += '*'
            all_style_files = glob(style_loc)
            assert (len(all_style_files) > 0)
            for loc in all_style_files:
                #print('loading '+loc)
                with open(loc, 'rb') as f:
                    styles = pickle.load(f)
                for i in range(len(styles['authors'])):
                    by_author_styles[styles['authors'][i]].append(
                        (styles['styles'][i], styles['ids'][i]))
                    by_author_all_ids[styles['authors'][i]].update(
                        styles['ids'][i])

            self.styles = defaultdict(lambda: defaultdict(list))
            for author in by_author_styles:
                for id in by_author_all_ids[author]:
                    for style, ids in by_author_styles[author]:
                        if id not in ids:
                            self.styles[author][id].append(style)

        else:
            self.styles = None

        self.mask_post = config['mask_post'] if 'mask_post' in config else []

        #DEBUG
        if 'overfit' in config and config['overfit']:
            self.lineIndex = self.lineIndex[:10]

        self.npr = np.random.RandomState(1234)
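
The per-author batching above recurs in several of these constructors; the standalone sketch below (the function name and framing are mine, not the repo's) reproduces it on bare indices, including the padded final batch.

# Indices are grouped into fixed-size batches; a short leftover batch is
# padded with indices from the front of the list so every batch is full.
def author_batches(n_items, batch_size):
    batches = [list(range(batch_size * i, batch_size * (i + 1)))
               for i in range(n_items // batch_size)]
    leftover = n_items % batch_size
    if leftover > 0:
        fill = batch_size - leftover
        last = list(range(fill))
        last += [n_items - (1 + i) for i in range(leftover)]
        batches.append(last)
    return batches

print(author_batches(7, 3))  # [[0, 1, 2], [3, 4, 5], [0, 1, 6]]
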
    def __init__(self, dirPath, split, config):

        self.img_height = config['img_height']
        self.a_batch_size = config['a_batch_size']
        num_style_per_word = config.get('num_style_per_word', 10)

        #with open(os.path.join(dirPath,'sets.json')) as f:
        with open(os.path.join('data', 'sets.json')) as f:
            set_list = json.load(f)[split]

        #self.authors = defaultdict(list)
        self.w_authors = defaultdict(list)
        for page_idx, name in enumerate(set_list):
            w_lines, lines, author = parseXML(
                os.path.join(dirPath, 'xmls', name + '.xml'))

            #authorLines = len(self.authors[author])
            #self.authors[author] += [(os.path.join(dirPath,'forms',name+'.png'),)+l for l in lines]
            for words in w_lines:
                self.w_authors[author] += [
                    (os.path.join(dirPath, 'forms', name + '.png'), ) + w
                    for w in words
                ]
            #self.lineIndex += [(author,i+authorLines) for i in range(len(lines))]
        #minLines=99999
        #for author,lines in self.authors.items():
        #print('{} {}'.format(author,len(lines)))
        #minLines = min(minLines,len(lines))
        #maxCombs = int(nCr(minLines,self.a_batch_size)*1.2)
        min_words = 999999
        for author, words in self.w_authors.items():
            min_words = min(min_words, len(words))
        self.a_batch_size = min(self.a_batch_size, min_words - 1)
        npr = np.random.RandomState(123)
        self.index = []
        for author, words in self.w_authors.items():
            #if split=='train':
            #    combs=list(itertools.combinations(list(range(len(lines))),self.a_batch_size))
            #    np.random.shuffle(combs)
            #    self.lineIndex += [(author,c) for c in combs[:maxCombs]]
            #else:
            for wi in range(len(words)):
                less_words = [wj for wj in range(len(words)) if wj != wi]
                # scipy.misc.comb is gone from modern SciPy; comb lives in scipy.special
                choose = scipy.special.comb(len(less_words), self.a_batch_size)
                for i in range(
                        int(min(1 + (choose - 1) / 2, num_style_per_word))):
                    # sample without replacement so a style batch never repeats a word
                    indexes = npr.choice(len(less_words), self.a_batch_size,
                                         replace=False)
                    selection = [less_words[j] for j in indexes]
                    self.index.append((author, selection))

        char_set_path = config['char_file']
        with open(char_set_path) as f:
            char_set = json.load(f)
        self.char_to_idx = char_set['char_to_idx']
        self.augmentation = config.get('augmentation')
        self.warning = False

        self.mask_post = config['mask_post'] if 'mask_post' in config else []

        #DEBUG
        if 'overfit' in config and config['overfit']:
            self.index = self.index[:10]
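
The exemplar sampling above, isolated on toy data: for each target word a few style batches are drawn from the author's other words, capped by the number of distinct combinations. Names and values below are illustrative.

import numpy as np
from scipy.special import comb

npr = np.random.RandomState(123)
words = ['the', 'quick', 'brown', 'fox', 'jumps']
a_batch_size = 2
num_style_per_word = 3

index = []
for wi in range(len(words)):
    less_words = [wj for wj in range(len(words)) if wj != wi]
    choose = comb(len(less_words), a_batch_size)  # C(4, 2) = 6 here
    for _ in range(int(min(1 + (choose - 1) / 2, num_style_per_word))):
        picks = npr.choice(len(less_words), a_batch_size, replace=False)
        index.append((wi, [less_words[j] for j in picks]))

print(len(index))  # 5 words x 3 sampled batches each = 15 entries
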
Example #4
    def __init__(self, dirPath, split, config):

        self.img_height = config['img_height']
        self.batch_size = config.get('a_batch_size', 1)
        #assert(config['batch_size']==1)

        #with open(os.path.join(dirPath,'sets.json')) as f:
        with open(os.path.join('data', 'lineMixed_sets.json')) as f:
            set_list = json.load(f)[split]

        self.authors = defaultdict(list)
        allLines = defaultdict(dict)
        pageAuthors = {}  # remember each page's author; cached pages must not reuse a stale value
        self.lineIndex = []
        for line_idx, line_id in enumerate(set_list):
            page_id = line_id[:line_id.rfind('-')]
            if page_id not in allLines:
                lines, author = parseXML(
                    os.path.join(dirPath, 'xmls', page_id + '.xml'))
                pageAuthors[page_id] = author
                for bounds, trans, id in lines:
                    allLines[page_id][id] = (bounds, trans)
            author = pageAuthors[page_id]
            bounds, trans = allLines[page_id][line_id]
            self.authors[author].append(
                (os.path.join(dirPath, 'forms',
                              page_id + '.png'), bounds, trans))
            #self.lineIndex += [(author,i+authorLines) for i in range(len(lines))]
        #minLines=99999
        #for author,lines in self.authors.items():
        #print('{} {}'.format(author,len(lines)))
        #minLines = min(minLines,len(lines))
        #maxCombs = int(nCr(minLines,self.batch_size)*1.2)
        short = config['short'] if 'short' in config else False
        for author, lines in self.authors.items():
            #if split=='train':
            #    combs=list(itertools.combinations(list(range(len(lines))),self.batch_size))
            #    np.random.shuffle(combs)
            #    self.lineIndex += [(author,c) for c in combs[:maxCombs]]
            #else:
            i = -1  # stays -1 when an author has fewer lines than a full batch
            for i in range(len(lines) // self.batch_size):
                ls = list(range(self.batch_size * i,
                                self.batch_size * (i + 1)))
                self.lineIndex.append((author, ls))
                if short and i >= short:
                    break
            if short and i >= short:
                continue
            leftover = len(lines) % self.batch_size
            if leftover > 0:  # pad a final short batch; skip when lines divide evenly
                fill = self.batch_size - leftover
                last = list(range(fill))
                last += [len(lines) - (1 + i) for i in range(leftover)]
                self.lineIndex.append((author, last))

            #if split=='train':
            #    ss = set(self.lineIndex)
        #self.authors = self.authors.keys()

        char_set_path = config['char_file']
        with open(char_set_path) as f:
            char_set = json.load(f)
        self.char_to_idx = char_set['char_to_idx']
        self.augmentation = config.get('augmentation')
        self.warning = False

        #DEBUG
        if 'overfit' in config and config['overfit']:
            self.lineIndex = self.lineIndex[:10]

        self.center = False  #config['center_pad'] #if 'center_pad' in config else True

        self.mask_post = config['mask_post'] if 'mask_post' in config else []
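
The slicing on word and line ids works because of IAM's hierarchical naming scheme: stripping the last '-'-separated field walks one level up, from word to line to page.

word_id = 'a01-000u-00-01'
line_id = word_id[:word_id.rfind('-')]
page_id = line_id[:line_id.rfind('-')]
print(line_id, page_id)  # a01-000u-00 a01-000u
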
Example #5
    def __init__(self, dirPath, split, config):
        if 'split' in config:
            split = config['split']

        self.img_height = config['img_height']
        self.batch_size = config['a_batch_size']
        self.no_spaces = config['no_spaces'] if 'no_spaces' in config else False
        self.max_width = config['max_width'] if 'max_width' in config else 3000
        #assert(config['batch_size']==1)
        self.warning=False

        self.triplet = config['triplet'] if 'triplet' in config else False
        if self.triplet:
            self.triplet_author_size = config['triplet_author_size']
            self.triplet_sample_size = config['triplet_sample_size']

        only_author = config['only_author'] if 'only_author' in config else None
        skip_author = config['skip_author'] if 'skip_author' in config else None

        #with open(os.path.join(dirPath,'sets.json')) as f:
        with open(os.path.join('data','sets.json')) as f:
            set_list = json.load(f)[split]

        self.authors = defaultdict(list)
        self.lineIndex = []
        self.max_char_len=0
        self.author_list=set()
        for page_idx, name in enumerate(set_list):
            lines,author = parseXML(os.path.join(dirPath,'xmls',name+'.xml'))
            self.author_list.add(author)
            if only_author is not None and type(only_author) is int and page_idx==only_author:
                only_author=author
                print('Only author: {}'.format(only_author))
            if only_author is not None and author!=only_author:
                continue
            if skip_author is not None and author==skip_author:
                continue
            self.max_char_len= max([self.max_char_len]+[len(l[1]) for l in lines])
            
            authorLines = len(self.authors[author])
            self.authors[author] += [(os.path.join(dirPath,'forms',name+'.png'),)+l for l in lines]
            #self.lineIndex += [(author,i+authorLines) for i in range(len(lines))]
        self.author_list = list(self.author_list)
        self.author_list.sort()
        #minLines=99999
        #for author,lines in self.authors.items():
            #print('{} {}'.format(author,len(lines)))
            #minLines = min(minLines,len(lines))
        #maxCombs = int(nCr(minLines,self.batch_size)*1.2)
        short = config['short'] if 'short' in config else False
        for author,lines in self.authors.items():
            #if split=='train':
            #    combs=list(itertools.combinations(list(range(len(lines))),self.batch_size))
            #    np.random.shuffle(combs)
            #    self.lineIndex += [(author,c) for c in combs[:maxCombs]]
            #else:
            i = -1  # stays -1 when an author has fewer lines than a full batch
            for i in range(len(lines)//self.batch_size):
                ls = list(range(self.batch_size*i, self.batch_size*(i+1)))
                self.lineIndex.append((author,ls))
                if short and i>=short:
                    break
            if short and i>=short:
                continue
            leftover = len(lines)%self.batch_size
            if leftover>0:  # pad a final short batch; skip when lines divide evenly
                fill = self.batch_size-leftover
                last = list(range(fill))
                last += [len(lines)-(1+i) for i in range(leftover)]
                self.lineIndex.append((author,last))
        self.fg_masks_dir = config['fg_masks_dir'] if 'fg_masks_dir' in config else None

        if self.fg_masks_dir is not None:
            if self.fg_masks_dir[-1]=='/':
                self.fg_masks_dir = self.fg_masks_dir[:-1]
            self.fg_masks_dir+='_{}'.format(self.max_width)
            ensure_dir(self.fg_masks_dir)
            for author,lines in self.lineIndex:
                for line in lines:
                    img_path, lb, gt = self.authors[author][line]
                    fg_path = os.path.join(self.fg_masks_dir,'{}_{}.png'.format(author,line))
                    if not os.path.exists(fg_path):
                        img = cv2.imread(img_path,0)[lb[0]:lb[1],lb[2]:lb[3]] #read as grayscale, crop line

                        if img.shape[0] != self.img_height:
                            if img.shape[0] < self.img_height and not self.warning:
                                self.warning = True
                                print("WARNING: upsampling image to fit size")
                            percent = float(self.img_height) / img.shape[0]
                            if img.shape[1]*percent > self.max_width:
                                percent = self.max_width/img.shape[1]
                            img = cv2.resize(img, (0,0), fx=percent, fy=percent, interpolation = cv2.INTER_CUBIC)
                            if img.shape[0]<self.img_height:
                                diff = self.img_height-img.shape[0]
                                img = np.pad(img,((diff//2,diff//2+diff%2),(0,0)),'constant',constant_values=255)

                        th,binarized = cv2.threshold(img,0,255,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
                        binarized = 255-binarized
                        ele = cv2.getStructuringElement(  cv2.MORPH_ELLIPSE, (9,9) )
                        binarized = cv2.dilate(binarized,ele)
                        cv2.imwrite(fg_path,binarized)
                        print('saved fg mask: {}'.format(fg_path))
                        #test_path = os.path.join(fg_masks_dir,'{}_{}_test.png'.format(author,line))
                        ##print(img.shape)
                        #img = np.stack((img,img,img),axis=2)
                        #img[:,:,0]=binarized
                        #cv2.imwrite(test_path,img)
                        #print('saved fg mask: {}'.format(fg_path))

            #if split=='train':
            #    ss = set(self.lineIndex)
        #self.authors = self.authors.keys()
                

        char_set_path = config['char_file']
        with open(char_set_path) as f:
            char_set = json.load(f)
        self.char_to_idx = char_set['char_to_idx']

        self.augmentation = config['augmentation'] if 'augmentation' in config else None
        self.normalized_dir = config['cache_normalized'] if 'cache_normalized' in config else None
        if self.normalized_dir is not None:
            ensure_dir(self.normalized_dir)
        self.max_strech=0.4
        self.max_rot_rad= 45/180 * math.pi

        self.remove_bg = config['remove_bg'] if 'remove_bg' in config else False
        self.include_stroke_aug = config['include_stroke_aug'] if 'include_stroke_aug' in config else False

        #DEBUG
        if 'overfit' in config and config['overfit']:
            self.lineIndex = self.lineIndex[:10]

        self.center = False #config['center_pad'] #if 'center_pad' in config else True

        if 'style_loc' in config:
            by_author_styles=defaultdict(list)
            by_author_all_ids=defaultdict(set)
            style_loc = config['style_loc']
            if style_loc[-1]!='*':
                style_loc+='*'
            all_style_files = glob(style_loc)
            assert( len(all_style_files)>0)
            for loc in all_style_files:
                #print('loading '+loc)
                with open(loc,'rb') as f:
                    styles = pickle.load(f)
                for i in range(len(styles['authors'])):
                    by_author_styles[styles['authors'][i]].append((styles['styles'][i],styles['ids'][i]))
                    by_author_all_ids[styles['authors'][i]].update(styles['ids'][i])

            self.styles = defaultdict(lambda: defaultdict(list))
            for author in by_author_styles:
                for id in by_author_all_ids[author]:
                    for style, ids in by_author_styles[author]:
                        if id not in ids:
                            self.styles[author][id].append(style)

            for author in self.authors:
                assert(author in self.styles)
        else:
            self.styles=None

        if 'spaced_loc' in config:
            with open(config['spaced_loc'],'rb') as f:
                self.spaced_by_name = pickle.load(f)
            #for name,v in spaced_by_name.items():
            #    author, id = name.split('_')
        else:
            self.spaced_by_name = None
            self.identity_spaced = config['no_spacing_for_spaced'] if 'no_spacing_for_spaced' in config else False

        self.mask_post = config['mask_post'] if 'mask_post' in config else []
        self.mask_random = config['mask_random'] if 'mask_random' in config else False
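
The foreground-mask branch above can be tried in isolation; the sketch below runs the same resize / Otsu / dilate sequence on a synthetic image rather than an IAM crop, with parameters mirroring the code.

import cv2
import numpy as np

img_height, max_width = 64, 3000
img = np.full((100, 400), 255, dtype=np.uint8)  # white "page" with fake ink
cv2.putText(img, 'example', (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 2, 0, 3)

# Scale to the target height, capping the resulting width.
percent = img_height / img.shape[0]
if img.shape[1] * percent > max_width:
    percent = max_width / img.shape[1]
img = cv2.resize(img, (0, 0), fx=percent, fy=percent,
                 interpolation=cv2.INTER_CUBIC)

# Otsu picks the threshold; inverting makes ink the foreground (255),
# and dilation grows it into a generous mask.
th, binarized = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
binarized = 255 - binarized
ele = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
binarized = cv2.dilate(binarized, ele)
cv2.imwrite('fg_mask_example.png', binarized)
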
Example #6
# Dump the ground-truth transcription of every test-set line to a text file.
import os, sys, json
from utils.parseIAM import getLineBoundaries as parseXML

dirPath = sys.argv[1]  # IAM root directory (contains xmls/ and forms/)
outPath = sys.argv[2]  # destination text file

with open(os.path.join('data', 'sets.json')) as f:
    set_list = json.load(f)['test']

texts = []
for page_idx, name in enumerate(set_list):
    lines, author = parseXML(os.path.join(dirPath, 'xmls', name + '.xml'))
    texts += [t for _, t in lines]

with open(outPath, 'w') as out:
    for t in texts:
        out.write('{}\n'.format(t))
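
Assuming the script above is saved as, say, dump_test_text.py (the file name is not given in the source), it would be invoked as: python dump_test_text.py /path/to/IAM test_texts.txt, where the first argument is the IAM root containing the xmls/ directory and the second is the output file, written one transcription per line.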