def __getitem__(self, idx):
    img = None
    while img is None:
        img = img_f.imread(self.index[idx])
        idx = (idx + 7) % len(self.index)  # skip ahead if the file could not be read
    if img.shape[0] != self.size or img.shape[1] != self.size:
        img = img_f.resize(img, (self.size, self.size), degree=0)
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    if img.max() >= 220:
        img = img / 255  # different versions of PyTorch convert bool/uint8 to float differently
    if img.size(0) == 4:  # alpha channel: fill it with a QR code image
        if self.qr_dataset is None:
            self.qr_dataset = SimpleQRDataset(None, 'train', {
                'str_len': 17,
                'final_size': self.size,
                'noise': False
            })
        qr_img = self.qr_dataset[0]['image']
        qr_img = (qr_img + 1) / 2
        qr_img = qr_img.expand(3, -1, -1).clone()  # convert to color
        qr_img[:, (img[3] > 0).bool()] = img[:3, (img[3] > 0).bool()]  # paste image where alpha > 0
        img = qr_img
    img = self.transform(img)
    img = img * 2 - 1
    return img, None
def __getitem__(self, idx):
    qr = qrcode.QRCode(
        version=1,
        error_correction=self.error_level,
        box_size=self.box_size,
        border=self.border,
        mask_pattern=self.mask_pattern,
    )
    if self.str_len is not None:
        #length = random.randrange(self.min_str_len,self.str_len+1)
        length = self.str_len
        gt_char = ''.join(random.choice(self.characters) for i in range(length))
    else:
        if self.indexes is not None:
            idx = self.indexes[idx]
        gt_char = '{}'.format(idx)
    qr.add_data(gt_char)
    qr.make(fit=True)
    img = qr.make_image(fill_color="black", back_color="white")
    img = np.array(img)
    if self.final_size is not None:
        img = img_f.resize(img, (self.final_size, self.final_size), degree=0)
    # Slight noise
    if self.noise:
        img = data_utils.gaussian_noise(img.astype(np.uint8) * 255, max_intensity=1)
    img = torch.from_numpy(img)[None, ...].float()
    if img.max() == 255:
        img = img / 255  # different versions of PyTorch convert bool/uint8 to float differently
    img = img * 2 - 1

    targetchar = torch.LongTensor(self.str_len).fill_(0)
    for i, c in enumerate(gt_char):
        targetchar[i] = self.char_to_index[c]
    targetvalid = torch.FloatTensor([1])

    if self.mask is not None:
        masked_img = self.mask * img
    else:
        masked_img = None

    return {
        "image": img,
        "gt_char": gt_char,
        'targetchar': targetchar,
        'targetvalid': targetvalid,
        "masked_img": masked_img
    }
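# A minimal usage sketch (not part of the original code). It assumes SimpleQRDataset is
# importable, that its constructor takes (dirPath, split, config) as in the
# SimpleQRDataset(None, 'train', {...}) call shown earlier, and that the remaining
# config keys fall back to sensible defaults.
dataset = SimpleQRDataset(None, 'train', {'str_len': 17, 'final_size': 256, 'noise': False})
sample = dataset[0]
print(sample['gt_char'])             # random 17-character string encoded in the QR code
print(sample['image'].shape,         # 1 x 256 x 256 tensor
      sample['image'].min().item(),  # values are scaled to [-1, 1]
      sample['image'].max().item())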
def makeQR(text, size=None):
    qr = qrcode.QRCode(
        version=1,
        error_correction=qrcode.constants.ERROR_CORRECT_L,
        box_size=11,
        border=2,
    )
    qr.add_data(text)
    qr.make(fit=True)
    img = qr.make_image(fill_color="black", back_color="white")
    img = np.array(img)[:, :, None].repeat(3, axis=2).astype(np.uint8) * 255
    if size is not None:
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_NEAREST)
    return img
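# A minimal usage sketch (the output file name is hypothetical): render a message with
# makeQR above and write the resulting HxWx3 uint8 array to disk with OpenCV.
qr_rgb = makeQR("hello world", size=256)
cv2.imwrite('demo_qr.png', qr_rgb)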
def generate_qr_code(self, gt_char):
    qr = qrcode.QRCode(
        version=1,
        error_correction=qrcode.constants.ERROR_CORRECT_L,
        box_size=1,
        border=self.border,
        mask_pattern=self.mask_pattern,
    )
    gt_char = str(gt_char)
    qr.add_data(gt_char)
    qr.make(fit=True)
    img = qr.make_image(fill_color="black", back_color="white")  #.resize(self.qr_size)
    img = np.array(img).astype(np.uint8) * 255
    img = img_f.resize(img, (self.final_size, self.final_size), degree=0)
    img = img[:, :, None].repeat(3, axis=2)
    return {"gt_char": gt_char, "image_undistorted": img}
    #:'allenywang': lambda qr,a: QRMatrix('decode',image=a).decode(),
    'pyzbar': lambda qr, a: zbar_decode(a)
}

images = list(Path("./imagenet/images/dogs").rglob("*"))
num_images = len(images)
test_strings = [
    "short",
    "medium 4io4\:][",
    "long sdfjka349:fg,.<>fgok4t-={}.///gf"
]
for text in test_strings:
    results = defaultdict(lambda: 0)
    avg_max_interpolation = defaultdict(list)
    qr_image = makeQR(text, 256)
    for imN, f in enumerate(images):
        background = cv2.imread(str(f))
        background = cv2.resize(background, (256, 256))
        for name, qr in qr_decoders.items():
            qr_d = qr_decode[name]
            max_hit = 0
            for p in range(20):
                mix = round(p * .05, 2)
                #s = superimpose(background, qr_image, mix)
                s = contrast(qr_image, mix)
                cv2.imwrite('tmp.png', s)
                s_ = cv2.imread('tmp.png')
                res = qr_d(qr, s_)
                if res == text:
def __init__(self, dirPath=None, split=None, config=None, images=None):
    super(FUNSDGraphPair, self).__init__(dirPath, split, config, images)
    self.only_types = None
    self.split_to_lines = config['split_to_lines']

    if images is not None:
        self.images = images
    else:
        if 'overfit' in config and config['overfit']:
            splitFile = 'overfit_split.json'
        else:
            splitFile = 'FUNSD_train_valid_test_split.json'
        with open(os.path.join(splitFile)) as f:
            readFile = json.loads(f.read())
        if type(split) is str:
            toUse = readFile[split]
            imagesAndAnn = []
            imageDir = os.path.join(dirPath, toUse['root'], 'images')
            annDir = os.path.join(dirPath, toUse['root'], 'annotations')
            for name in toUse['images']:
                imagesAndAnn.append(
                    (name + '.png',
                     os.path.join(imageDir, name + '.png'),
                     os.path.join(annDir, name + '.json')))
        elif type(split) is list:
            imagesAndAnn = []
            for spstr in split:
                toUse = readFile[spstr]
                imageDir = os.path.join(dirPath, toUse['root'], 'images')
                annDir = os.path.join(dirPath, toUse['root'], 'annotations')
                for name in toUse['images']:
                    imagesAndAnn.append(
                        (name + '.png',
                         os.path.join(imageDir, name + '.png'),
                         os.path.join(annDir, name + '.json')))
        else:
            print("Error, unknown split {}".format(split))
            exit()

        self.images = []
        for imageName, imagePath, jsonPath in imagesAndAnn:
            org_path = imagePath
            if self.cache_resized:
                path = os.path.join(self.cache_path, imageName)
            else:
                path = org_path
            if os.path.exists(jsonPath):
                rescale = 1.0
                if self.cache_resized:
                    rescale = self.rescale_range[1]
                    if not os.path.exists(path):
                        org_img = img_f.imread(org_path)
                        if org_img is None:
                            print('WARNING, could not read {}'.format(org_path))
                            continue
                        resized = img_f.resize(
                            org_img, (0, 0),
                            fx=self.rescale_range[1],
                            fy=self.rescale_range[1])
                        img_f.imwrite(path, resized)
                self.images.append({
                    'id': imageName,
                    'imagePath': path,
                    'annotationPath': jsonPath,
                    'rescaled': rescale,
                    'imageName': imageName[:imageName.rfind('.')]
                })

    self.only_types = None
    self.errors = []
    self.classMap = {
        'header': 16,
        'question': 17,
        'answer': 18,
        'other': 19
    }
    self.index_class_map = ['header', 'question', 'answer', 'other']
def __init__(self, dirPath=None, split=None, config=None, images=None):
    super(FormsGraphPair, self).__init__(dirPath, split, config, images)
    if 'only_types' in config:
        self.only_types = config['only_types']
    else:
        self.only_types = None
    if 'swap_circle' in config:
        self.swapCircle = config['swap_circle']
    else:
        self.swapCircle = False
    self.special_dataset = config['special_dataset'] if 'special_dataset' in config else None
    if 'simple_dataset' in config and config['simple_dataset']:
        self.special_dataset = 'simple'

    if images is not None:
        self.images = images
    else:
        if self.special_dataset is not None:
            splitFile = self.special_dataset + '_train_valid_test_split.json'
        else:
            splitFile = 'train_valid_test_split.json'
        with open(os.path.join(dirPath, splitFile)) as f:
            readFile = json.loads(f.read())
        if type(split) is str:
            groupsToUse = readFile[split]
        elif type(split) is list:
            groupsToUse = {}
            for spstr in split:
                newGroups = readFile[spstr]
                groupsToUse.update(newGroups)
        else:
            print("Error, unknown split {}".format(split))
            exit()

        self.images = []
        groupNames = list(groupsToUse.keys())
        groupNames.sort()
        for groupName in groupNames:
            imageNames = groupsToUse[groupName]
            if groupName in SKIP:
                print('Skipped group {}'.format(groupName))
                continue
            for imageName in imageNames:
                org_path = os.path.join(dirPath, 'groups', groupName, imageName)
                if self.cache_resized:
                    path = os.path.join(self.cache_path, imageName)
                else:
                    path = org_path
                jsonPath = org_path[:org_path.rfind('.')] + '.json'
                if os.path.exists(jsonPath):
                    rescale = 1.0
                    if self.cache_resized:
                        rescale = self.rescale_range[1]
                        if not os.path.exists(path):
                            org_img = img_f.imread(org_path)
                            if org_img is None:
                                print('WARNING, could not read {}'.format(org_path))
                                continue
                            resized = img_f.resize(
                                org_img, (0, 0),
                                fx=self.rescale_range[1],
                                fy=self.rescale_range[1])
                            img_f.imwrite(path, resized)
                    self.images.append({
                        'id': imageName,
                        'imagePath': path,
                        'annotationPath': jsonPath,
                        'rescaled': rescale,
                        'imageName': imageName[:imageName.rfind('.')]
                    })

    self.no_blanks = config['no_blanks'] if 'no_blanks' in config else False
    self.use_paired_class = config['use_paired_class'] if 'use_paired_class' in config else False
    if 'no_print_fields' in config:
        self.no_print_fields = config['no_print_fields']
    else:
        self.no_print_fields = False
    self.no_graphics = config['no_graphics'] if 'no_graphics' in config else False
    self.only_opposite_pairs = config['only_opposite_pairs'] if 'only_opposite_pairs' in config else False
    self.group_only_same = config['group_only_same'] if 'group_only_same' in config else False
    self.no_groups = config['no_groups'] if 'no_groups' in config else False
    assert not self.no_groups or not self.group_only_same
    if (self.group_only_same or self.no_groups) and self.only_opposite_pairs:
        print('Warning, you may want only_opposite_pairs off')
    if (self.group_only_same or self.no_groups) and not self.rotate:
        print('Warning, you may want rotation on')
    self.onlyFormStuff = False
    self.errors = []

    self.useClasses = config['use_classes'] if 'use_classes' in config else []
    self.classMap = {
        'textGeneric': 13,
        'fieldGeneric': 14,
    }
    for i, clas in enumerate(self.useClasses):
        self.classMap[clas] = i + 15
    if not self.no_blanks:
        self.classMap['blank'] = 15 + len(self.useClasses)
    if self.use_paired_class:
        self.classMap['paired'] = 15 + len(self.useClasses) + (0 if self.no_blanks else 1)
def getitem(self, index, scaleP=None, cropPoint=None):
    if self.useRandomAugProb is not None and np.random.rand() < self.useRandomAugProb \
            and scaleP is None and cropPoint is None:
        return self.getRandomImage()

    imagePath = self.images[index]['imagePath']
    imageName = self.images[index]['imageName']
    annotationPath = self.images[index]['annotationPath']
    rescaled = self.images[index]['rescaled']
    with open(annotationPath) as annFile:
        annotations = json.loads(annFile.read())

    np_img = img_f.imread(imagePath, 1 if self.color else 0)
    if np_img is None or np_img.shape[0] == 0:
        print("ERROR, could not open " + imagePath)
        return self.__getitem__((index + 1) % self.__len__())
    if np_img.max() < 200:
        np_img *= 255

    if scaleP is None:
        s = np.random.uniform(self.rescale_range[0], self.rescale_range[1])
    else:
        s = scaleP
    partial_rescale = s / rescaled
    if self.transform is None:  # we're doing the whole image
        # this is a check to be sure we don't send too-big images through
        pixel_count = partial_rescale * partial_rescale * np_img.shape[0] * np_img.shape[1]
        if pixel_count > self.pixel_count_thresh:
            partial_rescale = math.sqrt(
                partial_rescale * partial_rescale * self.pixel_count_thresh / pixel_count)
            print('{} exceed thresh: {}: {}, new {}: {}'.format(
                imageName, s, pixel_count, rescaled * partial_rescale,
                partial_rescale * partial_rescale * np_img.shape[0] * np_img.shape[1]))
            s = rescaled * partial_rescale

        max_dim = partial_rescale * max(np_img.shape[0], np_img.shape[1])
        if max_dim > self.max_dim_thresh:
            partial_rescale = partial_rescale * (self.max_dim_thresh / max_dim)
            print('{} exceed thresh: {}: {}, new {}: {}'.format(
                imageName, s, max_dim, rescaled * partial_rescale,
                partial_rescale * max(np_img.shape[0], np_img.shape[1])))
            s = rescaled * partial_rescale

    if np_img is not None:
        np_img = img_f.resize(
            np_img, (0, 0),
            fx=partial_rescale,
            fy=partial_rescale)
    if len(np_img.shape) == 2:
        np_img = np_img[..., None]  # add 'color' channel
    if self.color and np_img.shape[2] == 1:
        np_img = np.repeat(np_img, 3, axis=2)

    bbs, line_gts, point_gts, pixel_gt, numClasses, numNeighbors, pairs = self.parseAnn(
        np_img, annotations, s, imagePath)

    if self.coordConv:  # add absolute position information
        xs = 255 * np.arange(np_img.shape[1]) / (np_img.shape[1])
        xs = np.repeat(xs[None, :, None], np_img.shape[0], axis=0)
        ys = 255 * np.arange(np_img.shape[0]) / (np_img.shape[0])
        ys = np.repeat(ys[:, None, None], np_img.shape[1], axis=1)
        np_img = np.concatenate(
            (np_img, xs.astype(np_img.dtype), ys.astype(np_img.dtype)), axis=2)

    if self.transform is not None:
        pairs = None
        out, cropPoint = self.transform(
            {
                "img": np_img,
                "bb_gt": bbs,
                "bb_auxs": numNeighbors,
                "line_gt": line_gts,
                "point_gt": point_gts,
                "pixel_gt": pixel_gt,
            }, cropPoint)
        np_img = out['img']
        bbs = out['bb_gt']
        numNeighbors = out['bb_auxs']
        point_gts = out['point_gt']
        pixel_gt = out['pixel_gt']
        line_gts = out['line_gt']

        if self.color:
            np_img[:, :, :3] = augmentation.apply_random_color_rotation(np_img[:, :, :3])
            np_img[:, :, :3] = augmentation.apply_tensmeyer_brightness(np_img[:, :, :3])
        else:
            np_img[:, :, 0:1] = augmentation.apply_tensmeyer_brightness(np_img[:, :, 0:1])

    img = np_img.transpose([2, 0, 1])[None, ...]  # from [row,col,color] to [batch,color,row,col]
    img = img.astype(np.float32)
    img = torch.from_numpy(img)
    img = 1.0 - img / 128.0  # ideally the median value would be 0

    if pixel_gt is not None:
        pixel_gt = pixel_gt.transpose([2, 0, 1])[None, ...]
        pixel_gt = torch.from_numpy(pixel_gt)

    if line_gts is not None:
        for name in line_gts:
            line_gts[name] = None if line_gts[name] is None or line_gts[name].shape[1] == 0 \
                else torch.from_numpy(line_gts[name])

    bbs = convertBBs(bbs, self.rotate, numClasses)
    if numNeighbors is not None and len(numNeighbors) > 0:
        numNeighbors = torch.tensor(numNeighbors)[None, :]  # add batch dim
    else:
        numNeighbors = None

    if point_gts is not None:
        for name in point_gts:
            if point_gts[name] is not None:
                point_gts[name] = None if point_gts[name].shape[1] == 0 \
                    else torch.from_numpy(point_gts[name])

    if self.only_types is None:
        return {
            "img": img,
            "bb_gt": bbs,
            "num_neighbors": numNeighbors,
            "line_gt": line_gts,
            "point_gt": point_gts,
            "pixel_gt": pixel_gt,
            "imgName": imageName,
            "scale": s,
            "cropPoint": cropPoint,
            "pairs": pairs
        }
    else:
        if 'boxes' not in self.only_types or not self.only_types['boxes']:
            bbs = None
        line_gt = {}
        if 'line' in self.only_types:
            for ent in self.only_types['line']:
                if type(ent) == list:
                    toComb = []
                    for inst in ent[1:]:
                        einst = line_gts[inst]
                        if einst is not None:
                            toComb.append(einst)
                    if len(toComb) > 0:
                        line_gt[ent[0]] = torch.cat(toComb, dim=1)
                    else:
                        line_gt[ent[0]] = None
                else:
                    line_gt[ent] = line_gts[ent]
        point_gt = {}
        if 'point' in self.only_types:
            for ent in self.only_types['point']:
                if type(ent) == list:
                    toComb = []
                    for inst in ent[1:]:
                        einst = point_gts[inst]
                        if einst is not None:
                            toComb.append(einst)
                    if len(toComb) > 0:
                        point_gt[ent[0]] = torch.cat(toComb, dim=1)
                    else:
                        point_gt[ent[0]] = None
                else:
                    point_gt[ent] = point_gts[ent]
        pixel_gtR = None
        if 'pixel' in self.only_types:
            pixel_gtR = pixel_gt
        return {
            "img": img,
            "bb_gt": bbs,
            "num_neighbors": numNeighbors,
            "line_gt": line_gt,
            "point_gt": point_gt,
            "pixel_gt": pixel_gtR,
            "imgName": imageName,
            "scale": s,
            "cropPoint": cropPoint,
            "pairs": pairs,
        }
def getitem(self, index, scaleP=None, cropPoint=None):
    imagePath = self.images[index]['imagePath']
    imageName = self.images[index]['imageName']
    annotationPath = self.images[index]['annotationPath']
    rescaled = self.images[index]['rescaled']
    with open(annotationPath) as annFile:
        annotations = json.loads(annFile.read())

    np_img = img_f.imread(imagePath, 1 if self.color else 0)
    if np_img is None or np_img.shape[0] == 0:
        print("ERROR, could not open " + imagePath)
        return self.__getitem__((index + 1) % self.__len__())
    if np_img.max() < 200:
        np_img *= 255

    if scaleP is None:
        s = np.random.uniform(self.rescale_range[0], self.rescale_range[1])
    else:
        s = scaleP
    partial_rescale = s / rescaled
    if self.transform is None:  # we're doing the whole image
        # this is a check to be sure we don't send too-big images through
        pixel_count = partial_rescale * partial_rescale * np_img.shape[0] * np_img.shape[1]
        if pixel_count > self.pixel_count_thresh:
            partial_rescale = math.sqrt(
                partial_rescale * partial_rescale * self.pixel_count_thresh / pixel_count)
            print('{} exceed thresh: {}: {}, new {}: {}'.format(
                imageName, s, pixel_count, rescaled * partial_rescale,
                partial_rescale * partial_rescale * np_img.shape[0] * np_img.shape[1]))
            s = rescaled * partial_rescale

        max_dim = partial_rescale * max(np_img.shape[0], np_img.shape[1])
        if max_dim > self.max_dim_thresh:
            partial_rescale = partial_rescale * (self.max_dim_thresh / max_dim)
            print('{} exceed thresh: {}: {}, new {}: {}'.format(
                imageName, s, max_dim, rescaled * partial_rescale,
                partial_rescale * max(np_img.shape[0], np_img.shape[1])))
            s = rescaled * partial_rescale

    np_img = img_f.resize(
        np_img, (0, 0),
        fx=partial_rescale,
        fy=partial_rescale)
    if len(np_img.shape) == 2:
        np_img = np_img[..., None]  # add 'color' channel
    if self.color and np_img.shape[2] == 1:
        np_img = np.repeat(np_img, 3, axis=2)

    bbs, ids, numClasses, trans, groups, metadata, form_metadata = self.parseAnn(annotations, s)

    if self.questions:
        # we need to do questions before crop to have full context;
        # we use the relationships to build the questions
        pairs = set()
        for index1, id in enumerate(ids):
            responseBBIdList = self.getResponseBBIdList(id, annotations)
            for bbId in responseBBIdList:
                try:
                    index2 = ids.index(bbId)
                    pairs.add((min(index1, index2), max(index1, index2)))
                except ValueError:
                    pass
        groups_adj = set()
        if groups is not None:
            for n0, n1 in pairs:
                g0 = -1
                g1 = -1
                for i, ns in enumerate(groups):
                    if n0 in ns:
                        g0 = i
                        if g1 != -1:
                            break
                    if n1 in ns:
                        g1 = i
                        if g0 != -1:
                            break
                if g0 != g1:
                    groups_adj.add((min(g0, g1), max(g0, g1)))
        questions_and_answers = self.makeQuestions(bbs, trans, groups, groups_adj)
    else:
        questions_and_answers = None

    if self.transform is not None:
        if 'word_boxes' in form_metadata:
            word_bbs = form_metadata['word_boxes']
            dif_f = bbs.shape[2] - word_bbs.shape[1]
            blank = np.zeros([word_bbs.shape[0], dif_f])
            prep_word_bbs = np.concatenate([word_bbs, blank], axis=1)[None, ...]
            crop_bbs = np.concatenate([bbs, prep_word_bbs], axis=1)
            crop_ids = ids + ['word{}'.format(i) for i in range(word_bbs.shape[0])]
        else:
            crop_bbs = bbs
            crop_ids = ids
        out, cropPoint = self.transform(
            {
                "img": np_img,
                "bb_gt": crop_bbs,
                "bb_auxs": crop_ids,
            }, cropPoint)
        np_img = out['img']

        if 'word_boxes' in form_metadata:
            saw_word = False
            word_index = -1
            for i, ii in enumerate(out['bb_auxs']):
                if not saw_word:
                    if type(ii) is str and 'word' in ii:
                        saw_word = True
                        word_index = i
                else:
                    assert 'word' in ii
            bbs = out['bb_gt'][:, :word_index]
            ids = out['bb_auxs'][:word_index]
            form_metadata['word_boxes'] = out['bb_gt'][0, word_index:, :8]
            word_ids = out['bb_auxs'][word_index:]
            form_metadata['word_trans'] = [
                form_metadata['word_trans'][int(id[4:])] for id in word_ids
            ]
        else:
            bbs = out['bb_gt']
            ids = out['bb_auxs']

        if questions_and_answers is not None:
            questions = []
            answers = []
            questions_and_answers = [
                (q, a, qids) for q, a, qids in questions_and_answers
                if all((i in ids) for i in qids)
            ]

    if questions_and_answers is not None:
        if len(questions_and_answers) > self.questions:
            questions_and_answers = random.sample(questions_and_answers, k=self.questions)
        if len(questions_and_answers) > 0:
            questions, answers, _ = zip(*questions_and_answers)
        else:
            return self.getitem((index + 1) % len(self))
    else:
        questions = answers = None

    if np_img.shape[2] == 3:
        np_img = augmentation.apply_random_color_rotation(np_img)
        np_img = augmentation.apply_tensmeyer_brightness(np_img, **self.aug_params)
    else:
        np_img = augmentation.apply_tensmeyer_brightness(np_img, **self.aug_params)

    newGroups = []
    for group in groups:
        newGroup = [ids.index(bbId) for bbId in group if bbId in ids]
        if len(newGroup) > 0:
            newGroups.append(newGroup)
    groups = newGroups

    pairs = set()
    numNeighbors = [0] * len(ids)
    for index1, id in enumerate(ids):
        responseBBIdList = self.getResponseBBIdList(id, annotations)
        for bbId in responseBBIdList:
            try:
                index2 = ids.index(bbId)
                pairs.add((min(index1, index2), max(index1, index2)))
                numNeighbors[index1] += 1
            except ValueError:
                pass

    img = np_img.transpose([2, 0, 1])[None, ...]  # from [row,col,color] to [batch,color,row,col]
    img = img.astype(np.float32)
    img = torch.from_numpy(img)
    img = 1.0 - img / 128.0  # ideally the median value would be 0

    bbs = convertBBs(bbs, self.rotate, numClasses)
    if 'word_boxes' in form_metadata:
        form_metadata['word_boxes'] = convertBBs(
            form_metadata['word_boxes'][None, ...], self.rotate, 0)[0, ...]
    if len(numNeighbors) > 0:
        numNeighbors = torch.tensor(numNeighbors)[None, :]  # add batch dim
    else:
        numNeighbors = None

    groups_adj = set()
    if groups is not None:
        for n0, n1 in pairs:
            g0 = -1
            g1 = -1
            for i, ns in enumerate(groups):
                if n0 in ns:
                    g0 = i
                    if g1 != -1:
                        break
                if n1 in ns:
                    g1 = i
                    if g0 != -1:
                        break
            if g0 != g1:
                groups_adj.add((min(g0, g1), max(g0, g1)))
    for group in groups:
        for i in group:
            assert i < bbs.shape[1]
    targetIndexToGroup = {}
    for groupId, bbIds in enumerate(groups):
        targetIndexToGroup.update({bbId: groupId for bbId in bbIds})

    transcription = [trans[id] for id in ids]

    return {
        "img": img,
        "bb_gt": bbs,
        "num_neighbors": numNeighbors,
        "adj": pairs,  # upper-triangular set of index pairs, as pairings are bi-directional
        "imgName": imageName,
        "scale": s,
        "cropPoint": cropPoint,
        "transcription": transcription,
        "metadata": [metadata[id] for id in ids if id in metadata],
        "form_metadata": form_metadata,
        "gt_groups": groups,
        "targetIndexToGroup": targetIndexToGroup,
        "gt_groups_adj": groups_adj,
        "questions": questions,
        "answers": answers
    }
def main(resume, saveDir, numberOfImages, message, qr_size, qr_border, qr_version,
         gpu=None, config=None, addToConfig=None):
    if resume is not None:
        checkpoint = torch.load(resume, map_location=lambda storage, location: storage)
        print('loaded iteration {}'.format(checkpoint['iteration']))
        loaded_iteration = checkpoint['iteration']
        if config is None:
            config = checkpoint['config']
        else:
            config = json.load(open(config))
        for key in config.keys():
            if type(config[key]) is dict:
                for key2 in config[key].keys():
                    if key2.startswith('pretrained'):
                        config[key][key2] = None
    else:
        checkpoint = None
        config = json.load(open(config))
        loaded_iteration = None

    config['optimizer_type'] = "none"
    config['trainer']['use_learning_schedule'] = False
    config['trainer']['swa'] = False
    if gpu is None:
        config['cuda'] = False
    else:
        config['cuda'] = True
        config['gpu'] = gpu

    addDATASET = False
    if addToConfig is not None:
        for add in addToConfig:
            addTo = config
            printM = 'added config['
            for i in range(len(add) - 2):
                addTo = addTo[add[i]]
                printM += add[i] + ']['
            value = add[-1]
            if value == "":
                value = None
            elif value[0] == '[' and value[-1] == ']':
                value = value[1:-1].split('-')
            else:
                try:
                    value = int(value)
                except ValueError:
                    try:
                        value = float(value)
                    except ValueError:
                        pass
            addTo[add[-2]] = value
            printM += add[-2] + ']={}'.format(value)
            print(printM)
            if (add[-2] == 'useDetections' or add[-2] == 'useDetect') and value != 'gt':
                addDATASET = True

    if checkpoint is not None:
        if 'state_dict' in checkpoint:
            model = eval(config['model']['arch'])(config['model'])
            keys = list(checkpoint['state_dict'].keys())
            my_state = model.state_dict()
            my_keys = list(my_state.keys())
            for mkey in my_keys:
                if mkey not in keys:
                    checkpoint['state_dict'][mkey] = my_state[mkey]
            for ckey in keys:
                if ckey not in my_keys:
                    del checkpoint['state_dict'][ckey]
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model = checkpoint['model']
    else:
        model = eval(config['arch'])(config['model'])
    model.eval()
    if gpu is not None:
        model = model.to(gpu)

    # generate a normal QR code
    qr = qrcode.QRCode(
        version=qr_version,
        error_correction=qrcode.constants.ERROR_CORRECT_L,
        box_size=1,
        border=qr_border,
        mask_pattern=None,
    )
    qr.add_data(message)
    qr.make(fit=True)
    qr_img = qr.make_image(fill_color="black", back_color="white")
    qr_img = np.array(qr_img)
    qr_img = img_f.resize(qr_img, (qr_size, qr_size), degree=0)
    if qr_img.max() == 255:
        qr_img = qr_img / 255
    qr_img = qr_img * 2 - 1
    qr_img = torch.from_numpy(qr_img[None, None, ...]).float()  # add batch and color dims
    qr_img = qr_img.expand(numberOfImages, -1, -1, -1)

    with torch.no_grad():
        if gpu is not None:
            qr_img = qr_img.to(gpu)
        gen_image = model(qr_img)
        gen_image = gen_image.clamp(-1, 1)
        gen_image = (gen_image + 1) / 2
        gen_image = gen_image.cpu().permute(0, 2, 3, 1)

    for b in range(numberOfImages):
        path = os.path.join(saveDir, '{}.png'.format(b))
        img_f.imwrite(path, gen_image[b])
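# A minimal sketch of calling main() directly (all values are hypothetical; the original
# presumably wraps this in an argparse CLI, which is not shown here).
if __name__ == '__main__':
    main(resume='saved/checkpoint-latest.pth',  # assumed checkpoint path
         saveDir='out_images',
         numberOfImages=4,
         message='https://example.com',
         qr_size=256,
         qr_border=2,
         qr_version=4,
         gpu=None)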