Example #1
    def __getitem__(self, idx):

        img = None
        while img is None:
            img = img_f.imread(self.index[idx])
            idx = (idx + 7) % len(self.index)  #try a different index if the read failed
        if img.shape[0] != self.size or img.shape[1] != self.size:
            img = img_f.resize(img, (self.size, self.size), degree=0)
        img = torch.from_numpy(img).permute(2, 0, 1).float()
        if img.max() >= 220:
            img = img / 255  #I think different versions of PyTorch do the conversion from bool to float differently
        if img.size(0) == 4:  #alpha channel, we'll fill it with a QR code image
            if self.qr_dataset is None:
                self.qr_dataset = SimpleQRDataset(None, 'train', {
                    'str_len': 17,
                    'final_size': self.size,
                    'noise': False
                })
            qr_img = self.qr_dataset[0]['image']
            qr_img = (qr_img + 1) / 2
            qr_img = qr_img.expand(3, -1, -1).clone()  #convert to color
            alpha_mask = img[3] > 0
            qr_img[:, alpha_mask] = img[:3, alpha_mask]  #add image where alpha is >0
            img = qr_img
        img = self.transform(img)
        img = img * 2 - 1

        return img, None
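
The alpha branch above pastes the RGB content of an RGBA image over a QR code wherever alpha is nonzero. A minimal self-contained sketch of that masking trick (tensor shapes and names are illustrative):

import torch

rgba = torch.rand(4, 64, 64)                           # RGBA image in [0, 1]
qr = torch.rand(1, 64, 64).expand(3, -1, -1).clone()   # grayscale QR as 3 channels

alpha_mask = rgba[3] > 0                    # (64, 64) boolean mask
qr[:, alpha_mask] = rgba[:3, alpha_mask]    # overwrite QR pixels where alpha > 0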
Example #2
    def __getitem__(self, idx):

        qr = qrcode.QRCode(
            version=1,
            error_correction=self.error_level,
            box_size=self.box_size,
            border=self.border,
            mask_pattern=self.mask_pattern,
        )
        if self.str_len is not None:
            #length = random.randrange(self.min_str_len,self.str_len+1)
            length = self.str_len
            gt_char = ''.join(random.choice(self.characters) for i in range(length))
        else:
            if self.indexes is not None:
                idx = self.indexes[idx]
            gt_char = '{}'.format(idx)
        qr.add_data(gt_char)
        qr.make(fit=True)
        img = qr.make_image(fill_color="black", back_color="white")
        img = np.array(img)
        if self.final_size is not None:
            img = img_f.resize(img,(self.final_size,self.final_size),degree=0)

        # Slight noise
        if self.noise:
            img = data_utils.gaussian_noise(img.astype(np.uint)*255, max_intensity=1)

        img = torch.from_numpy(img)[None, ...].float()
        if img.max() == 255:
            img = img / 255  #I think different versions of PyTorch do the conversion from bool to float differently
        img = img * 2 - 1

        target_len = self.str_len if self.str_len is not None else len(gt_char)
        targetchar = torch.LongTensor(target_len).fill_(0)
        for i, c in enumerate(gt_char):
            targetchar[i] = self.char_to_index[c]
        targetvalid = torch.FloatTensor([1])

        if self.mask is not None:
            masked_img = self.mask * img
        else:
            masked_img = None

        #debug visualization (disabled):
        #plot(img, mask=False)
        #plot(masked_img, mask=False)
        #plot(self.mask, mask=True)

        return {
            "image": img,
            "gt_char": gt_char,
            'targetchar': targetchar,
            'targetvalid': targetvalid,
            "masked_img": masked_img
        }
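
Assuming this is SimpleQRDataset.__getitem__ (the class instantiated in Example #1), a sketch of typical use; the config values mirror those in Example #1 and final_size=256 is an assumed value:

ds = SimpleQRDataset(None, 'train', {'str_len': 17, 'final_size': 256, 'noise': False})
sample = ds[0]
img = sample['image']        # (1, 256, 256) tensor scaled to [-1, 1]
print(sample['gt_char'])     # the random string encoded in the QR code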
Example #3
def makeQR(text, size=None):
    qr = qrcode.QRCode(
        version=1,
        error_correction=qrcode.constants.ERROR_CORRECT_L,
        box_size=11,
        border=2,
    )
    qr.add_data(text)
    qr.make(fit=True)

    img = qr.make_image(fill_color="black", back_color="white")
    img = np.array(img)[:, :, None].repeat(3, axis=2).astype(np.uint8) * 255
    if size is not None:
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_NEAREST)  #nearest keeps the modules sharp
    return img
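
A quick usage sketch for makeQR (the output path is illustrative):

import cv2

img = makeQR("https://example.com", size=256)   # (256, 256, 3) uint8, values 0/255
cv2.imwrite("qr_example.png", img)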
Example #4
    def generate_qr_code(self, gt_char):
        qr = qrcode.QRCode(
            version=1,
            error_correction=qrcode.constants.ERROR_CORRECT_L,
            box_size=1,
            border=self.border,
            mask_pattern=self.mask_pattern,
        )
        gt_char = str(gt_char)
        qr.add_data(gt_char)
        qr.make(fit=True)  #WARNING: fit=True may raise the version above 1, changing the module count
        img = qr.make_image(fill_color="black", back_color="white")
        img = np.array(img).astype(np.uint8) * 255
        img = img_f.resize(img, (self.final_size, self.final_size), degree=0)
        img = img[:, :, None].repeat(3, axis=2)

        return {"gt_char": gt_char, "image_undistorted": img}
Example #5
        #'allenywang': lambda qr, a: QRMatrix('decode', image=a).decode(),
        'pyzbar': lambda qr, a: zbar_decode(a)
    }
    images = list(Path("./imagenet/images/dogs").rglob("*"))
    num_images = len(images)
    test_strings = [
        "short", "medium 4io4\\:][", "long sdfjka349:fg,.<>fgok4t-={}.///gf"
    ]

    for text in test_strings:
        results = defaultdict(lambda: 0)
        avg_max_interpolation = defaultdict(list)
        qr_image = makeQR(text, 256)
        for imN, f in enumerate(images):
            background = cv2.imread(str(f))
            background = cv2.resize(background, (256, 256))
            for name, qr in qr_decoders.items():
                qr_d = qr_decode[name]

                max_hit = 0
                for p in range(20):
                    mix = round(p * .05, 2)
                    #s = superimpose(background, qr_image, mix)
                    s = contrast(qr_image, mix)
                    cv2.imwrite('tmp.png', s)
                    s_ = cv2.imread('tmp.png')
                    res = qr_d(qr, s_)
                    if res == text:
                        results[name] += 1  #hypothetical continuation (the source snippet is truncated here)
                        max_hit = mix
                avg_max_interpolation[name].append(max_hit)
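
superimpose and contrast are not shown in this snippet, so the following stand-ins are assumptions about their behavior, not the original code:

import cv2
import numpy as np

def superimpose(background, qr_image, mix):
    #hypothetical: alpha-blend the QR code over the background
    return cv2.addWeighted(background, mix, qr_image, 1.0 - mix, 0)

def contrast(image, amount):
    #hypothetical: pull pixel values toward mid-gray to reduce contrast
    gray = np.full_like(image, 128)
    return cv2.addWeighted(image, 1.0 - amount, gray, amount, 0)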
Example #6
    def __init__(self, dirPath=None, split=None, config=None, images=None):
        super(FUNSDGraphPair, self).__init__(dirPath, split, config, images)

        self.only_types = None

        self.split_to_lines = config['split_to_lines']

        if images is not None:
            self.images = images
        else:
            if 'overfit' in config and config['overfit']:
                splitFile = 'overfit_split.json'
            else:
                splitFile = 'FUNSD_train_valid_test_split.json'
            with open(splitFile) as f:
                #if split=='valid' or split=='validation':
                #    trainTest='train'
                #else:
                #    trainTest=split
                readFile = json.loads(f.read())
                if type(split) is str:
                    toUse = readFile[split]
                    imagesAndAnn = []
                    imageDir = os.path.join(dirPath, toUse['root'], 'images')
                    annDir = os.path.join(dirPath, toUse['root'],
                                          'annotations')
                    for name in toUse['images']:
                        imagesAndAnn.append(
                            (name + '.png',
                             os.path.join(imageDir, name + '.png'),
                             os.path.join(annDir, name + '.json')))
                elif type(split) is list:
                    imagesAndAnn = []
                    for spstr in split:
                        toUse = readFile[spstr]
                        imageDir = os.path.join(dirPath, toUse['root'],
                                                'images')
                        annDir = os.path.join(dirPath, toUse['root'],
                                              'annotations')
                        for name in toUse['images']:
                            imagesAndAnn.append(
                                (name + '.png',
                                 os.path.join(imageDir, name + '.png'),
                                 os.path.join(annDir, name + '.json')))
                else:
                    print("Error, unknown split {}".format(split))
                    exit()
            self.images = []
            for imageName, imagePath, jsonPath in imagesAndAnn:
                org_path = imagePath
                if self.cache_resized:
                    path = os.path.join(self.cache_path, imageName)
                else:
                    path = org_path
                if os.path.exists(jsonPath):
                    rescale = 1.0
                    if self.cache_resized:
                        rescale = self.rescale_range[1]
                        if not os.path.exists(path):
                            org_img = img_f.imread(org_path)
                            if org_img is None:
                                print('WARNING, could not read {}'.format(org_path))
                                continue
                            resized = img_f.resize(
                                org_img,
                                (0, 0),
                                fx=self.rescale_range[1],
                                fy=self.rescale_range[1],
                            )
                            img_f.imwrite(path, resized)
                    self.images.append({
                        'id': imageName,
                        'imagePath': path,
                        'annotationPath': jsonPath,
                        'rescaled': rescale,
                        'imageName': imageName[:imageName.rfind('.')]
                    })
        self.errors = []

        self.classMap = {
            'header': 16,
            'question': 17,
            'answer': 18,
            'other': 19
        }
        self.index_class_map = ['header', 'question', 'answer', 'other']
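
A sketch of constructing this dataset; the directory layout is hypothetical, and the keys beyond split_to_lines (cache_resized, rescale_range, ...) are consumed by the parent class:

config = {'split_to_lines': True}   #hypothetical minimal config
ds = FUNSDGraphPair(dirPath='data/FUNSD', split='train', config=config)
print(len(ds.images), ds.classMap)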
Example #7
    def __init__(self, dirPath=None, split=None, config=None, images=None):
        super(FormsGraphPair, self).__init__(dirPath, split, config, images)

        self.only_types = config.get('only_types')
        self.swapCircle = config.get('swap_circle', False)

        self.special_dataset = config.get('special_dataset')
        if config.get('simple_dataset'):
            self.special_dataset = 'simple'

        if images is not None:
            self.images = images
        else:
            if self.special_dataset is not None:
                splitFile = self.special_dataset + '_train_valid_test_split.json'
            else:
                splitFile = 'train_valid_test_split.json'
            with open(os.path.join(dirPath, splitFile)) as f:
                readFile = json.loads(f.read())
                if type(split) is str:
                    groupsToUse = readFile[split]
                elif type(split) is list:
                    groupsToUse = {}
                    for spstr in split:
                        newGroups = readFile[spstr]
                        groupsToUse.update(newGroups)
                else:
                    print("Error, unknown split {}".format(split))
                    exit()
            self.images = []
            groupNames = list(groupsToUse.keys())
            groupNames.sort()

            for groupName in groupNames:
                imageNames = groupsToUse[groupName]

                #print('{} {}'.format(groupName, imageNames))
                #oneonly=False
                if groupName in SKIP:
                    print('Skipped group {}'.format(groupName))
                    continue
                #    if groupName in ONE_DONE:
                #        oneonly=True
                #        with open(os.path.join(dirPath,'groups',groupName,'template'+groupName+'.json')) as f:
                #            T_annotations = json.loads(f.read())
                #    else:
                for imageName in imageNames:
                    #if oneonly and T_annotations['imageFilename']!=imageName:
                    #    #print('skipped {} {}'.format(imageName,groupName))
                    #    continue
                    #elif oneonly:
                    #    print('only {} from {}'.format(imageName,groupName))
                    org_path = os.path.join(dirPath, 'groups', groupName,
                                            imageName)
                    if self.cache_resized:
                        path = os.path.join(self.cache_path, imageName)
                    else:
                        path = org_path
                    jsonPath = org_path[:org_path.rfind('.')] + '.json'
                    #print(jsonPath)
                    if os.path.exists(jsonPath):
                        rescale = 1.0
                        if self.cache_resized:
                            rescale = self.rescale_range[1]
                            if not os.path.exists(path):
                                org_img = img_f.imread(org_path)
                                if org_img is None:
                                    print('WARNING, could not read {}'.format(org_path))
                                    continue
                                #target_dim1 = self.rescale_range[1]
                                #target_dim0 = int(org_img.shape[0]/float(org_img.shape[1]) * target_dim1)
                                #resized = img_f.resize(org_img,(target_dim1, target_dim0))
                                resized = img_f.resize(
                                    org_img,
                                    (0, 0),
                                    fx=self.rescale_range[1],
                                    fy=self.rescale_range[1],
                                )
                                img_f.imwrite(path, resized)
                                #rescale = target_dim1/float(org_img.shape[1])
                        #elif self.cache_resized:
                        #print(jsonPath)
                        #with open(jsonPath) as f:
                        #    annotations = json.loads(f.read())
                        #imW = annotations['width']

                        #target_dim1 = self.rescale_range[1]
                        #rescale = target_dim1/float(imW)
                        #print('addint {}'.format(imageName))
                        self.images.append({
                            'id': imageName,
                            'imagePath': path,
                            'annotationPath': jsonPath,
                            'rescaled': rescale,
                            'imageName': imageName[:imageName.rfind('.')]
                        })
                    #else:
                    #    print('couldnt find {}'.format(jsonPath))

                    # with open(path+'.json') as f:
                    #    annotations = json.loads(f.read())
                    #    imH = annotations['height']
                    #    imW = annotations['width']
                    #    #startCount=len(self.instances)
                    #    for bb in annotations['textBBs']:

        self.no_blanks = config.get('no_blanks', False)
        self.use_paired_class = config.get('use_paired_class', False)
        self.no_print_fields = config.get('no_print_fields', False)
        self.no_graphics = config.get('no_graphics', False)
        self.only_opposite_pairs = config.get('only_opposite_pairs', False)

        self.group_only_same = config.get('group_only_same', False)
        self.no_groups = config.get('no_groups', False)
        assert (not self.no_groups or not self.group_only_same)
        if (self.group_only_same or self.no_groups) and self.only_opposite_pairs:
            print('Warning, you may want only_opposite_pairs off')
        if (self.group_only_same or self.no_groups) and not self.rotate:
            print('Warning, you may want rotation on')

        self.onlyFormStuff = False
        self.errors = []

        self.useClasses = config.get('use_classes', [])
        self.classMap = {
            'textGeneric': 13,
            'fieldGeneric': 14,
        }
        for i, clas in enumerate(self.useClasses):
            self.classMap[clas] = i + 15
        if not self.no_blanks:
            self.classMap['blank'] = 15 + len(self.useClasses)
        if self.use_paired_class:
            self.classMap['paired'] = 15 + len(self.useClasses) + (0 if self.no_blanks else 1)
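
To make the class-index layout concrete, here is the construction run on a hypothetical config with two extra classes, blanks kept, and the paired class enabled:

useClasses = ['circle', 'table']                 # hypothetical config['use_classes']
classMap = {'textGeneric': 13, 'fieldGeneric': 14}
for i, clas in enumerate(useClasses):
    classMap[clas] = i + 15                      # extra classes start at 15
classMap['blank'] = 15 + len(useClasses)         # 17, since no_blanks is False
classMap['paired'] = 15 + len(useClasses) + 1    # 18, shifted past 'blank'
print(classMap)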
Example #8
    def getitem(self, index, scaleP=None, cropPoint=None):
        if (self.useRandomAugProb is not None and np.random.rand() < self.useRandomAugProb
                and scaleP is None and cropPoint is None):
            return self.getRandomImage()
        ##ticFull=timeit.default_timer()
        imagePath = self.images[index]['imagePath']
        imageName = self.images[index]['imageName']
        annotationPath = self.images[index]['annotationPath']
        #print(annotationPath)
        rescaled = self.images[index]['rescaled']
        with open(annotationPath) as annFile:
            annotations = json.loads(annFile.read())

        ##tic=timeit.default_timer()
        np_img = img_f.imread(imagePath, 1 if self.color else 0)  #/255.0
        if np_img is None or np_img.shape[0] == 0:
            print("ERROR, could not open " + imagePath)
            return self.__getitem__((index + 1) % self.__len__())
        if np_img.max() < 200:
            np_img *= 255

        if scaleP is None:
            s = np.random.uniform(self.rescale_range[0], self.rescale_range[1])
        else:
            s = scaleP
        partial_rescale = s / rescaled
        if self.transform is None:  #we're doing the whole image
            #this is a check to be sure we don't send too big images through
            pixel_count = partial_rescale * partial_rescale * np_img.shape[0] * np_img.shape[1]
            if pixel_count > self.pixel_count_thresh:
                partial_rescale = math.sqrt(partial_rescale * partial_rescale *
                                            self.pixel_count_thresh /
                                            pixel_count)
                print('{} exceed thresh: {}: {}, new {}: {}'.format(
                    imageName, s, pixel_count, rescaled * partial_rescale,
                    partial_rescale * partial_rescale * np_img.shape[0] *
                    np_img.shape[1]))
                s = rescaled * partial_rescale

            max_dim = partial_rescale * max(np_img.shape[0], np_img.shape[1])
            if max_dim > self.max_dim_thresh:
                partial_rescale = partial_rescale * (self.max_dim_thresh /
                                                     max_dim)
                print('{} exceed thresh: {}: {}, new {}: {}'.format(
                    imageName, s, max_dim, rescaled * partial_rescale,
                    partial_rescale * max(np_img.shape[0], np_img.shape[1])))
                s = rescaled * partial_rescale
        #print('! dataset transformed {}'.format(np_img.shape))

        ##tic=timeit.default_timer()
        #np_img = img_f.resize(np_img,(target_dim1, target_dim0))
        np_img = img_f.resize(
            np_img,
            (0, 0),
            fx=partial_rescale,
            fy=partial_rescale,
        )
        #print('! dataset resize')
        if len(np_img.shape) == 2:
            np_img = np_img[..., None]  #add 'color' channel
        if self.color and np_img.shape[2] == 1:
            np_img = np.repeat(np_img, 3, axis=2)
        ##print('resize: {}  [{}, {}]'.format(timeit.default_timer()-tic,np_img.shape[0],np_img.shape[1]))

        bbs, line_gts, point_gts, pixel_gt, numClasses, numNeighbors, pairs = self.parseAnn(
            np_img, annotations, s, imagePath)

        if self.coordConv:  #add absolute position information
            xs = 255 * np.arange(np_img.shape[1]) / (np_img.shape[1])
            xs = np.repeat(xs[None, :, None], np_img.shape[0], axis=0)
            ys = 255 * np.arange(np_img.shape[0]) / (np_img.shape[0])
            ys = np.repeat(ys[:, None, None], np_img.shape[1], axis=1)
            np_img = np.concatenate(
                (np_img, xs.astype(np_img.dtype), ys.astype(np_img.dtype)),
                axis=2)

        ##ticTr=timeit.default_timer()
        if self.transform is not None:
            pairs = None
            out, cropPoint = self.transform(
                {
                    "img": np_img,
                    "bb_gt": bbs,
                    "bb_auxs": numNeighbors,
                    "line_gt": line_gts,
                    "point_gt": point_gts,
                    "pixel_gt": pixel_gt,
                }, cropPoint)
            np_img = out['img']
            bbs = out['bb_gt']
            numNeighbors = out['bb_auxs']
            #if 'table_points' in out['point_gt']:
            #    table_points = out['point_gt']['table_points']
            #else:
            #    table_points=None
            point_gts = out['point_gt']
            pixel_gt = out['pixel_gt']
            #start_of_line = out['line_gt']['start_of_line']
            #end_of_line = out['line_gt']['end_of_line']
            line_gts = out['line_gt']

            ##tic=timeit.default_timer()
            if self.color:
                np_img[:, :, :3] = augmentation.apply_random_color_rotation(
                    np_img[:, :, :3])
                np_img[:, :, :3] = augmentation.apply_tensmeyer_brightness(
                    np_img[:, :, :3])
            else:
                np_img[:, :, 0:1] = augmentation.apply_tensmeyer_brightness(
                    np_img[:, :, 0:1])
            ##print('augmentation: {}'.format(timeit.default_timer()-tic))
        ##print('transfrm: {}  [{}, {}]'.format(timeit.default_timer()-ticTr,org_img.shape[0],org_img.shape[1]))

        #if len(np_img.shape)==2:
        #    img=np_img[None,None,:,:] #add "color" channel and batch
        #else:
        img = np_img.transpose([2, 0, 1])[None, ...]  #from [row,col,color] to [batch,color,row,col]
        img = img.astype(np.float32)
        img = torch.from_numpy(img)
        img = 1.0 - img / 128.0  #ideally the median value would be 0

        if pixel_gt is not None:
            pixel_gt = pixel_gt.transpose([2, 0, 1])[None, ...]
            pixel_gt = torch.from_numpy(pixel_gt)

        #start_of_line = None if start_of_line is None or start_of_line.shape[1] == 0 else torch.from_numpy(start_of_line)
        #end_of_line = None if end_of_line is None or end_of_line.shape[1] == 0 else torch.from_numpy(end_of_line)
        if line_gts is not None:
            for name in line_gts:
                gt = line_gts[name]
                line_gts[name] = None if gt is None or gt.shape[1] == 0 else torch.from_numpy(gt)

        #bbs = None if bbs.shape[1] == 0 else torch.from_numpy(bbs)
        bbs = convertBBs(bbs, self.rotate, numClasses)
        if numNeighbors is not None and len(numNeighbors) > 0:
            numNeighbors = torch.tensor(numNeighbors)[None, :]  #add batch dim
        else:
            numNeighbors = None
            #start_of_line = convertLines(start_of_line,numClasses)
        #end_of_line = convertLines(end_of_line,numClasses)
        if point_gts is not None:
            for name in point_gts:
                #if table_points is not None:
                #table_points = None if table_points.shape[1] == 0 else torch.from_numpy(table_points)
                if point_gts[name] is not None:
                    gt = point_gts[name]
                    point_gts[name] = None if gt.shape[1] == 0 else torch.from_numpy(gt)

        ##print('__getitem__: '+str(timeit.default_timer()-ticFull))
        if self.only_types is None:
            return {
                "img": img,
                "bb_gt": bbs,
                "num_neighbors": numNeighbors,
                "line_gt": line_gts,
                "point_gt": point_gts,
                "pixel_gt": pixel_gt,
                "imgName": imageName,
                "scale": s,
                "cropPoint": cropPoint,
                "pairs": pairs
            }
        else:
            if 'boxes' not in self.only_types or not self.only_types['boxes']:
                bbs = None
            line_gt = {}
            if 'line' in self.only_types:
                for ent in self.only_types['line']:
                    if type(ent) == list:
                        toComb = []
                        for inst in ent[1:]:
                            einst = line_gts[inst]
                            if einst is not None:
                                toComb.append(einst)
                        if len(toComb) > 0:
                            comb = torch.cat(toComb, dim=1)
                            line_gt[ent[0]] = comb
                        else:
                            line_gt[ent[0]] = None
                    else:
                        line_gt[ent] = line_gts[ent]
            point_gt = {}
            if 'point' in self.only_types:
                for ent in self.only_types['point']:
                    if type(ent) == list:
                        toComb = []
                        for inst in ent[1:]:
                            einst = point_gts[inst]
                            if einst is not None:
                                toComb.append(einst)
                        if len(toComb) > 0:
                            comb = torch.cat(toComb, dim=1)
                            point_gt[ent[0]] = comb
                        else:
                            point_gt[ent[0]] = None
                    else:
                        point_gt[ent] = point_gts[ent]
            pixel_gtR = None
            #for ent in self.only_types['pixel']:
            #    if type(ent)==list:
            #        comb = ent[1]
            #        for inst in ent[2:]:
            #            comb = (comb + inst)==2 #:eq(2) #pixel-wise AND
            #        pixel_gt[ent[0]]=comb
            #    else:
            #        pixel_gt[ent]=eval(ent)
            if 'pixel' in self.only_types:  # and self.only_types['pixel'][0]=='table_pixels':
                pixel_gtR = pixel_gt

            return {
                "img": img,
                "bb_gt": bbs,
                "num_neighbors": numNeighbors,
                "line_gt": line_gt,
                "point_gt": point_gt,
                "pixel_gt": pixel_gtR,
                "imgName": imageName,
                "scale": s,
                "cropPoint": cropPoint,
                "pairs": pairs,
            }
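
The coordConv branch appends two channels encoding each pixel's absolute x and y position, scaled to [0, 255). A minimal demonstration of the channel construction:

import numpy as np

h, w = 3, 5
xs = 255 * np.arange(w) / w                    # x position per column
xs = np.repeat(xs[None, :, None], h, axis=0)   # -> (h, w, 1)
ys = 255 * np.arange(h) / h                    # y position per row
ys = np.repeat(ys[:, None, None], w, axis=1)   # -> (h, w, 1)
img = np.zeros((h, w, 1), dtype=np.float32)
img = np.concatenate((img, xs.astype(img.dtype), ys.astype(img.dtype)), axis=2)
print(img.shape)   # (3, 5, 3): original channel plus x and y position channels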
Example #9
    def getitem(self, index, scaleP=None, cropPoint=None):
        ##ticFull=timeit.default_timer()
        imagePath = self.images[index]['imagePath']
        imageName = self.images[index]['imageName']
        annotationPath = self.images[index]['annotationPath']
        #print(annotationPath)
        rescaled = self.images[index]['rescaled']
        with open(annotationPath) as annFile:
            annotations = json.loads(annFile.read())

        ##tic=timeit.default_timer()
        np_img = img_f.imread(imagePath, 1 if self.color else 0)  #*255.0
        if np_img is None or np_img.shape[0] == 0:
            print("ERROR, could not open " + imagePath)
            return self.__getitem__((index + 1) % self.__len__())
        if np_img.max() < 200:
            np_img *= 255
        if scaleP is None:
            s = np.random.uniform(self.rescale_range[0], self.rescale_range[1])
        else:
            s = scaleP
        partial_rescale = s / rescaled
        if self.transform is None:  #we're doing the whole image
            #this is a check to be sure we don't send too big images through
            pixel_count = partial_rescale * partial_rescale * np_img.shape[0] * np_img.shape[1]
            if pixel_count > self.pixel_count_thresh:
                partial_rescale = math.sqrt(partial_rescale * partial_rescale *
                                            self.pixel_count_thresh /
                                            pixel_count)
                print('{} exceed thresh: {}: {}, new {}: {}'.format(
                    imageName, s, pixel_count, rescaled * partial_rescale,
                    partial_rescale * partial_rescale * np_img.shape[0] *
                    np_img.shape[1]))
                s = rescaled * partial_rescale

            max_dim = partial_rescale * max(np_img.shape[0], np_img.shape[1])
            if max_dim > self.max_dim_thresh:
                partial_rescale = partial_rescale * (self.max_dim_thresh /
                                                     max_dim)
                print('{} exceed thresh: {}: {}, new {}: {}'.format(
                    imageName, s, max_dim, rescaled * partial_rescale,
                    partial_rescale * max(np_img.shape[0], np_img.shape[1])))
                s = rescaled * partial_rescale

        ##tic=timeit.default_timer()
        #np_img = img_f.resize(np_img,(target_dim1, target_dim0))
        np_img = img_f.resize(
            np_img,
            (0, 0),
            fx=partial_rescale,
            fy=partial_rescale,
        )
        if len(np_img.shape) == 2:
            np_img = np_img[..., None]  #add 'color' channel
        if self.color and np_img.shape[2] == 1:
            np_img = np.repeat(np_img, 3, axis=2)
        ##print('resize: {}  [{}, {}]'.format(timeit.default_timer()-tic,np_img.shape[0],np_img.shape[1]))

        ##tic=timeit.default_timer()

        bbs, ids, numClasses, trans, groups, metadata, form_metadata = self.parseAnn(
            annotations, s)
        #trans = {i:v for i,v in enumerate(trans)}
        #metadata = {i:v for i,v in enumerate(metadata)}

        #start_of_line, end_of_line = getStartEndGT(annotations['byId'].values(),s)
        #Try:
        #    table_points, table_pixels = self.getTables(
        #            fieldBBs,
        #            s,
        #            np_img.shape[0],
        #            np_img.shape[1],
        #            annotations['samePairs'])
        #Except Exception as inst:
        #    if imageName not in self.errors:
        #        table_points=None
        #        table_pixels=None
        #        print(inst)
        #        print('Table error on: '+imagePath)
        #        self.errors.append(imageName)

        #pixel_gt = table_pixels

        ##ticTr=timeit.default_timer()
        if self.questions:  #we need to do questions before crop to have full context
            #we use the relationships to build the questions
            pairs = set()
            for index1, id in enumerate(ids):  #updated
                responseBBIdList = self.getResponseBBIdList(id, annotations)
                for bbId in responseBBIdList:
                    try:
                        index2 = ids.index(bbId)
                        pairs.add((min(index1, index2), max(index1, index2)))
                    except ValueError:
                        pass
            groups_adj = set()
            if groups is not None:
                for n0, n1 in pairs:
                    g0 = -1
                    g1 = -1
                    for i, ns in enumerate(groups):
                        if n0 in ns:
                            g0 = i
                            if g1 != -1:
                                break
                        if n1 in ns:
                            g1 = i
                            if g0 != -1:
                                break
                    if g0 != g1:
                        groups_adj.add((min(g0, g1), max(g0, g1)))
            questions_and_answers = self.makeQuestions(bbs, trans, groups,
                                                       groups_adj)
        else:
            questions_and_answers = None

        if self.transform is not None:
            if 'word_boxes' in form_metadata:
                word_bbs = form_metadata['word_boxes']
                dif_f = bbs.shape[2] - word_bbs.shape[1]
                blank = np.zeros([word_bbs.shape[0], dif_f])
                prep_word_bbs = np.concatenate([word_bbs, blank], axis=1)[None, ...]
                crop_bbs = np.concatenate([bbs, prep_word_bbs], axis=1)
                crop_ids = ids + [
                    'word{}'.format(i) for i in range(word_bbs.shape[0])
                ]
            else:
                crop_bbs = bbs
                crop_ids = ids
            out, cropPoint = self.transform(
                {
                    "img": np_img,
                    "bb_gt": crop_bbs,
                    'bb_auxs': crop_ids,
                    #'word_bbs':form_metadata['word_boxes'] if 'word_boxes' in form_metadata else None
                    #"line_gt": {
                    #    "start_of_line": start_of_line,
                    #    "end_of_line": end_of_line
                    #    },
                    #"point_gt": {
                    #        "table_points": table_points
                    #        },
                    #"pixel_gt": pixel_gt,
                },
                cropPoint)
            np_img = out['img']

            if 'word_boxes' in form_metadata:
                saw_word = False
                word_index = -1
                for i, ii in enumerate(out['bb_auxs']):
                    if not saw_word:
                        if type(ii) is str and 'word' in ii:
                            saw_word = True
                            word_index = i
                    else:
                        assert 'word' in ii
                bbs = out['bb_gt'][:, :word_index]
                ids = out['bb_auxs'][:word_index]
                form_metadata['word_boxes'] = out['bb_gt'][0, word_index:, :8]
                word_ids = out['bb_auxs'][word_index:]
                form_metadata['word_trans'] = [
                    form_metadata['word_trans'][int(id[4:])] for id in word_ids
                ]
            else:
                bbs = out['bb_gt']
                ids = out['bb_auxs']

            if questions_and_answers is not None:
                questions = []
                answers = []
                questions_and_answers = [
                    (q, a, qids) for q, a, qids in questions_and_answers
                    if all((i in ids) for i in qids)
                ]
        if questions_and_answers is not None:
            if len(questions_and_answers) > self.questions:
                questions_and_answers = random.sample(questions_and_answers,
                                                      k=self.questions)
            if len(questions_and_answers) > 0:
                questions, answers, _ = zip(*questions_and_answers)
            else:
                return self.getitem((index + 1) % len(self))
        else:
            questions = answers = None

            ##tic=timeit.default_timer()
            if np_img.shape[2] == 3:
                np_img = augmentation.apply_random_color_rotation(np_img)
                np_img = augmentation.apply_tensmeyer_brightness(
                    np_img, **self.aug_params)
            else:
                np_img = augmentation.apply_tensmeyer_brightness(
                    np_img, **self.aug_params)
            ##print('augmentation: {}'.format(timeit.default_timer()-tic))
        newGroups = []
        for group in groups:
            newGroup = [ids.index(bbId) for bbId in group if bbId in ids]
            if len(newGroup) > 0:
                newGroups.append(newGroup)
                #print(len(newGroups)-1,newGroup)
        groups = newGroups
        ##print('transfrm: {}  [{}, {}]'.format(timeit.default_timer()-ticTr,org_img.shape[0],org_img.shape[1]))
        pairs = set()
        #import pdb;pdb.set_trace()
        numNeighbors = [0] * len(ids)
        for index1, id in enumerate(ids):  #updated
            responseBBIdList = self.getResponseBBIdList(id, annotations)
            for bbId in responseBBIdList:
                try:
                    index2 = ids.index(bbId)
                    #adjMatrix[min(index1,index2),max(index1,index2)]=1
                    pairs.add((min(index1, index2), max(index1, index2)))
                    numNeighbors[index1] += 1
                except ValueError:
                    pass
        #ones = torch.ones(len(pairs))
        #if len(pairs)>0:
        #    pairs = torch.LongTensor(list(pairs)).t()
        #else:
        #    pairs = torch.LongTensor(pairs)
        #adjMatrix = torch.sparse.FloatTensor(pairs,ones,(len(ids),len(ids))) # This is an upper diagonal matrix as pairings are bi-directional

        #if len(np_img.shape)==2:
        #    img=np_img[None,None,:,:] #add "color" channel and batch
        #else:
        img = np_img.transpose([2, 0, 1])[None, ...]  #from [row,col,color] to [batch,color,row,col]
        img = img.astype(np.float32)
        img = torch.from_numpy(img)
        img = 1.0 - img / 128.0  #ideally the median value would be 0
        #if pixel_gt is not None:
        #    pixel_gt = pixel_gt.transpose([2,0,1])[None,...]
        #    pixel_gt = torch.from_numpy(pixel_gt)

        #start_of_line = None if start_of_line is None or start_of_line.shape[1] == 0 else torch.from_numpy(start_of_line)
        #end_of_line = None if end_of_line is None or end_of_line.shape[1] == 0 else torch.from_numpy(end_of_line)

        bbs = convertBBs(bbs, self.rotate, numClasses)
        if 'word_boxes' in form_metadata:
            form_metadata['word_boxes'] = convertBBs(
                form_metadata['word_boxes'][None, ...], self.rotate, 0)[0, ...]
        if len(numNeighbors) > 0:
            numNeighbors = torch.tensor(numNeighbors)[None, :]  #add batch dim
        else:
            numNeighbors = None
        #if table_points is not None:
        #    table_points = None if table_points.shape[1] == 0 else torch.from_numpy(table_points)
        groups_adj = set()
        targetIndexToGroup = {}  #initialized here so the return below works even if groups is None
        if groups is not None:
            for n0, n1 in pairs:
                g0 = -1
                g1 = -1
                for i, ns in enumerate(groups):
                    if n0 in ns:
                        g0 = i
                        if g1 != -1:
                            break
                    if n1 in ns:
                        g1 = i
                        if g0 != -1:
                            break
                if g0 != g1:
                    groups_adj.add((min(g0, g1), max(g0, g1)))
            for group in groups:
                for i in group:
                    assert (i < bbs.shape[1])
            targetIndexToGroup = {}
            for groupId, bbIds in enumerate(groups):
                targetIndexToGroup.update({bbId: groupId for bbId in bbIds})

        transcription = [trans[id] for id in ids]

        return {
            "img": img,
            "bb_gt": bbs,
            "num_neighbors": numNeighbors,
            "adj": pairs,  #adjMatrix,
            "imgName": imageName,
            "scale": s,
            "cropPoint": cropPoint,
            "transcription": transcription,
            "metadata": [metadata[id] for id in ids if id in metadata],
            "form_metadata": form_metadata,
            "gt_groups": groups,
            "targetIndexToGroup": targetIndexToGroup,
            "gt_groups_adj": groups_adj,
            "questions": questions,
            "answers": answers
        }
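
The pair-to-group adjacency computation appears twice in this method (once before the crop for question generation, once after). A hypothetical helper capturing the same logic with a dictionary lookup in place of the inner scan:

def group_adjacency(pairs, groups):
    #hypothetical refactor: map node-level pairs to group-level adjacency
    node_to_group = {}
    for gi, nodes in enumerate(groups):
        for n in nodes:
            node_to_group[n] = gi
    adj = set()
    for n0, n1 in pairs:
        g0 = node_to_group.get(n0, -1)
        g1 = node_to_group.get(n1, -1)
        if g0 != g1:
            adj.add((min(g0, g1), max(g0, g1)))
    return adj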
Example #10
def main(resume,
         saveDir,
         numberOfImages,
         message,
         qr_size,
         qr_border,
         qr_version,
         gpu=None,
         config=None,
         addToConfig=None):
    if resume is not None:
        checkpoint = torch.load(resume,
                                map_location=lambda storage, location: storage)
        print('loaded iteration {}'.format(checkpoint['iteration']))
        loaded_iteration = checkpoint['iteration']
        if config is None:
            config = checkpoint['config']
        else:
            config = json.load(open(config))
        for key in config.keys():
            if type(config[key]) is dict:
                for key2 in config[key].keys():
                    if key2.startswith('pretrained'):
                        config[key][key2] = None
    else:
        checkpoint = None
        config = json.load(open(config))
        loaded_iteration = None
    config['optimizer_type'] = "none"
    config['trainer']['use_learning_schedule'] = False
    config['trainer']['swa'] = False
    if gpu is None:
        config['cuda'] = False
    else:
        config['cuda'] = True
        config['gpu'] = gpu
    addDATASET = False
    if addToConfig is not None:
        for add in addToConfig:
            addTo = config
            printM = 'added config['
            for i in range(len(add) - 2):
                addTo = addTo[add[i]]
                printM += add[i] + ']['
            value = add[-1]
            if value == "":
                value = None
            elif value[0] == '[' and value[-1] == ']':
                value = value[1:-1].split('-')
            else:
                try:
                    value = int(value)
                except ValueError:
                    try:
                        value = float(value)
                    except ValueError:
                        pass
            addTo[add[-2]] = value
            printM += add[-2] + ']={}'.format(value)
            print(printM)
            if (add[-2] == 'useDetections'
                    or add[-2] == 'useDetect') and value != 'gt':
                addDATASET = True

    if checkpoint is not None:
        if 'state_dict' in checkpoint:
            model = eval(config['model']['arch'])(config['model'])
            keys = list(checkpoint['state_dict'].keys())
            my_state = model.state_dict()
            my_keys = list(my_state.keys())
            for mkey in my_keys:
                if mkey not in keys:
                    checkpoint['state_dict'][mkey] = my_state[mkey]
                #else:
                #    print('{} me: {}, load: {}'.format(mkey,my_state[mkey].size(),checkpoint['state_dict'][mkey].size()))
            for ckey in keys:
                if ckey not in my_keys:
                    del checkpoint['state_dict'][ckey]
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model = checkpoint['model']
    else:
        model = eval(config['arch'])(config['model'])
    model.eval()
    if gpu is not None:
        model = model.to(gpu)

    #generate normal QR code
    qr = qrcode.QRCode(
        version=qr_version,
        error_correction=qrcode.constants.ERROR_CORRECT_L,
        box_size=1,
        border=qr_border,
        mask_pattern=None  #self.mask_pattern,
    )
    qr.add_data(message)
    qr.make(fit=True)
    qr_img = qr.make_image(fill_color="black", back_color="white")
    qr_img = np.array(qr_img)
    qr_img = img_f.resize(qr_img, (qr_size, qr_size), degree=0)
    if qr_img.max() == 255:
        qr_img = qr_img / 255
    qr_img = qr_img * 2 - 1

    qr_img = torch.from_numpy(qr_img[None, None, ...]).float()  #add batch and color dims
    qr_img = qr_img.expand(numberOfImages, -1, -1, -1)

    with torch.no_grad():
        if gpu is not None:
            qr_img = qr_img.to(gpu)
        gen_image = model(qr_img)
        gen_image = gen_image.clamp(-1, 1)
        gen_image = (gen_image + 1) / 2
        gen_image = gen_image.cpu().permute(0, 2, 3, 1)
        for b in range(numberOfImages):
            path = os.path.join(saveDir, '{}.png'.format(b))
            img_f.imwrite(path, gen_image[b])
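
A sketch of calling main directly; paths and settings are placeholders, and the real script presumably collects them from the command line:

main(resume='saved/checkpoint-latest.pth',
     saveDir='out',
     numberOfImages=4,
     message='hello world',
     qr_size=256,
     qr_border=2,
     qr_version=1,
     gpu=None)   # CPU; pass a device index to run on CUDA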