Example #1
    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        imgpath = self.lines[index].rstrip()

        if self.train and index % 64 == 0:
            if self.seen < 4000 * 64:
                width = 13 * 32
                self.shape = (width, width)
            elif self.seen < 8000 * 64:
                width = (random.randint(0, 3) + 13) * 32
                self.shape = (width, width)
            elif self.seen < 12000 * 64:
                width = (random.randint(0, 5) + 12) * 32
                self.shape = (width, width)
            elif self.seen < 16000 * 64:
                width = (random.randint(0, 7) + 11) * 32
                self.shape = (width, width)
            else:  # self.seen < 20000*64
                width = (random.randint(0, 9) + 10) * 32
                self.shape = (width, width)
            # Note: this final assignment overrides the multi-scale width chosen
            # above and pins the input shape to 416x416.
            self.shape = (416, 416)

        if self.train:
            jitter = 0.2
            hue = 0.1
            saturation = 1.5 
            exposure = 1.5

            img, label = load_data_detection(imgpath, self.shape, jitter, hue, saturation, exposure)
            label = torch.from_numpy(label)
        else:
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img = img.resize(self.shape)
            
            labpath = imgpath.replace('train_x/', 'train_y/').replace('.png','.txt')
            label = torch.zeros(50*5)
            #if os.path.getsize(labpath):
            #tmp = torch.from_numpy(np.loadtxt(labpath))
            try:
                tmp = torch.from_numpy(read_truths_args(labpath, 8.0/img.width).astype('float32'))
            except Exception:
                tmp = torch.zeros(1,5)
            #tmp = torch.from_numpy(read_truths(labpath))
            tmp = tmp.view(-1)
            tsz = tmp.numel()
            #print('labpath = %s , tsz = %d' % (labpath, tsz))
            if tsz > 50*5:
                label = tmp[0:50*5]
            elif tsz > 0:
                label[0:tsz] = tmp

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers
        return (img, label)
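For context, a `__getitem__` like the one above is normally driven by a `torch.utils.data.DataLoader`. A minimal usage sketch follows; the class name `listDataset`, the module name `dataset`, and the list-file path are assumptions borrowed from the common pytorch-yolo2 layout, not taken from the snippet itself:

import torch
from torchvision import transforms
from dataset import listDataset  # assumed module/class name

train_set = listDataset('train.txt', shape=(416, 416), shuffle=True,
                        transform=transforms.ToTensor(),
                        train=True, seen=0, batch_size=64, num_workers=4)
loader = torch.utils.data.DataLoader(train_set, batch_size=64,
                                     shuffle=False, num_workers=4)
for img, label in loader:
    pass  # feed (img, label) into the YOLO training step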
Example #2
    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        imgpath = self.lines[index].rstrip()
        img_name = imgpath.split('/')[-1].strip()  # defined up front: the return below needs it in both branches

        if not self.train:  # note: unlike the other examples, the augmented load path runs when train is False
            jitter = 0.2
            hue = 0.1
            saturation = 1.5 
            exposure = 1.5

            img, label = load_data_detection(imgpath, self.shape, jitter, hue, saturation, exposure)
            label = torch.from_numpy(label)
        else:
            #print('===============imgpath==================',imgpath)  # runs through one batch of images consecutively
            img = Image.open(imgpath).convert('RGB')
            
            #if self.shape:
            #    img = img.resize(self.shape)
            #labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
            labpath = imgpath.replace('/img4/', '/img4_pred_label/').replace('.jpg', '.txt')
            #print('===============labpath==================',labpath)
            # 1. Random block replacement: input is the original image and the label path; output is the replaced image and its mask
            #print('====================================================',img.size[0],img.size[1])
            
            img_rep = rand_repl2(img, labpath)
            img = Image.fromarray(img_rep.astype('uint8')).convert('RGB')
            # Note: the step above works on PIL Images, while the repair step below needs OpenCV arrays
        
            # 2. Block repair: takes the outputs above and returns the repaired image
            #img2 = cv2.imread(imgpath)
            #img=patch_repa(img_rep,img_mask)

            # the original label stays unchanged
            
            label = torch.zeros(50*5)
            #if os.path.getsize(labpath):
            #tmp = torch.from_numpy(np.loadtxt(labpath))
            try:
                tmp = torch.from_numpy(read_truths_args(labpath, 8.0/img.width).astype('float32'))
            except Exception:
                tmp = torch.zeros(1,5)
            #tmp = torch.from_numpy(read_truths(labpath))
            tmp = tmp.view(-1)
            tsz = tmp.numel()
            #print('labpath = %s , tsz = %d' % (labpath, tsz))
            if tsz > 50*5:
                label = tmp[0:50*5]
            elif tsz > 0:
                label[0:tsz] = tmp

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers
        return (img_name,img, label)
Example #3
    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        imgpath = self.lines[index].rstrip()

        if self.train:
            if self.seen % (
                    self.batch_size * 10
            ) == 0:  # in paper, every 10 batches, but we did every 64 images
                self.shape = self.get_different_scale()
            img, label = load_data_detection(imgpath, self.shape, self.crop,
                                             self.jitter, self.hue,
                                             self.saturation, self.exposure)
            label = torch.from_numpy(label)
        else:
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img, org_w, org_h = letterbox_image(
                    img, self.shape[0], self.shape[1]), img.width, img.height

            labpath = imgpath.replace('images', 'labels').replace(
                'JPEGImages',
                'labels').replace('.jpg', '.txt').replace('.png', '.txt')
            label = torch.zeros(50 * 5)
            #if os.path.getsize(labpath):
            #tmp = torch.from_numpy(np.loadtxt(labpath))
            try:
                tmp = torch.from_numpy(
                    read_truths_args(labpath,
                                     8.0 / img.width).astype('float32'))
            except Exception:
                tmp = torch.zeros(1, 5)
            #tmp = torch.from_numpy(read_truths(labpath))
            tmp = tmp.view(-1)
            tsz = tmp.numel()
            #print('labpath = %s , tsz = %d' % (labpath, tsz))
            if tsz > 50 * 5:
                label = tmp[0:50 * 5]
            elif tsz > 0:
                label[0:tsz] = tmp

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers
        if self.train:
            return (img, label)
        else:
            # return (img, label, org_w, org_h)
            return (img, label, org_w, org_h, imgpath)
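Example #3 delegates the multi-scale choice to `self.get_different_scale()`, which is not shown here. A plausible sketch of such a helper, assuming the YOLOv2 recipe of drawing a random multiple of 32 between 320 and 608 (the repository's actual helper may differ):

import random

def get_different_scale(self):
    # Draw a random network input width from {320, 352, ..., 608},
    # i.e. a multiple of 32, per the YOLOv2 multi-scale training recipe.
    width = (random.randint(0, 9) + 10) * 32
    return (width, width)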
Example #4
    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        imgp = self.lines[index].rstrip()
        imgpath = imgp + '.jpg'
        imgpath = os.path.join(self.imgdirpath, imgpath)

        if self.train:
            jitter = 0.2
            hue = 0.1
            saturation = 1.5
            exposure = 1.5

            img, label = load_data_detection(imgpath, self.shape, jitter, hue,
                                             saturation, exposure,
                                             self.labdirpath, imgp)
            label = torch.from_numpy(label).float()
        elif self.test_txt:
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img = img.resize(self.shape)
            if self.transform is not None:
                img = self.transform(img)
            return img, imgp
        else:
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img = img.resize(self.shape)

            labpath = imgp + '.txt'
            labpath = os.path.join(self.labdirpath, labpath)
            label = torch.zeros(50 * 5)
            try:
                tmp = torch.from_numpy(
                    read_truths_args(labpath,
                                     6.0 / img.width).astype('float32'))
            except Exception:
                print(' No target !!!!!!!!!')
                tmp = torch.zeros(1, 5)
            tmp = tmp.view(-1)
            tsz = tmp.numel()
            if tsz > 50 * 5:
                label = tmp[0:50 * 5]
            elif tsz > 0:
                label[0:tsz] = tmp

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return (img, label)
Example #5
def test_loader(dataset_dir=data_dir):
    filenames = glob.glob('{}/images/*.jpg'.format(dataset_dir))
    #filenames = [dataset_dir+'/images/2005.jpg']

    loader = []
    for filename in filenames:
        image_raw, data = preprocessor.process(filename)
        shape_orig_WH = image_raw.size
        labfile = filename.replace('images', 'labels').replace('.jpg', '.txt')
        try:
            target = torch.from_numpy(
                read_truths_args(labfile, 8.0 / 416).astype('float32'))
        except Exception:
            target = None  #torch.zeros(1,5)
        loader.append([data, target, shape_orig_WH])
    return loader
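Nearly every example above leans on `read_truths_args(labpath, min_box_scale)`. In the widely copied pytorch-yolo2 utilities this reads an N x 5 label file of normalized (class, cx, cy, w, h) rows and drops boxes narrower than `min_box_scale`; that is why callers pass `8.0 / img.width`, discarding boxes under roughly 8 pixels wide. A sketch under that assumption:

import numpy as np

def read_truths(labpath):
    # Each row of the label file is (class, cx, cy, w, h), normalized to [0, 1].
    return np.loadtxt(labpath).reshape(-1, 5)

def read_truths_args(labpath, min_box_scale):
    # Keep only the boxes whose normalized width is at least min_box_scale.
    truths = read_truths(labpath)
    kept = [t for t in truths if t[3] >= min_box_scale]
    return np.array(kept).reshape(-1, 5)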
Example #6
    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        imgpath = self.parent_dir + self.lines[index].rstrip()
        if self.train:
            # jitter = 0.2
            jitter = 0.1
            hue = 0.05
            saturation = 1.5
            exposure = 1.5

            # Get background image path
            random_bg_index = random.randint(0, len(self.bg_file_names) - 1)
            bgpath = self.bg_file_names[random_bg_index]

            img, label = load_data_detection(imgpath, self.shape, jitter, hue,
                                             saturation, exposure, bgpath)
            label = torch.from_numpy(label)
        else:
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img = img.resize(self.shape)

            labpath = imgpath.replace('benchvise', self.objclass).replace(
                'images',
                'labels_occlusion').replace('JPEGImages',
                                            'labels_occlusion').replace(
                                                '.jpg', '.txt').replace(
                                                    '.png', '.txt')
            label = torch.zeros(50 * 21)
            if os.path.getsize(labpath):
                ow, oh = img.size
                tmp = torch.from_numpy(read_truths_args(labpath))
                tmp = tmp.view(-1)
                tsz = tmp.numel()
                if tsz > 50 * 21:
                    label = tmp[0:50 * 21]
                elif tsz > 0:
                    label[0:tsz] = tmp

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers
        return (img, label)
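Examples #6, #11, #14, and #20 composite a random background drawn from `self.bg_file_names`. That list is typically built once from a directory of distractor images (VOC JPEGImages in the singleshotpose setup); a sketch under that assumption, with the directory path purely illustrative:

import glob

def get_all_files(directory):
    # Recursively collect background image paths to populate bg_file_names.
    files = []
    for ext in ('*.jpg', '*.png'):
        files.extend(glob.glob('{}/**/{}'.format(directory, ext), recursive=True))
    return sorted(files)

bg_file_names = get_all_files('VOCdevkit/VOC2012/JPEGImages')  # illustrative path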
Example #7
    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        imgpath = self.lines[index].rstrip()

        if self.train:
            jitter = 0.2
            hue = 0.1
            saturation = 1.5
            exposure = 1.5

            img, label = load_data_detection(imgpath, self.shape, jitter, hue,
                                             saturation, exposure)
            label = torch.from_numpy(label)
        else:
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img = img.resize(self.shape)

            labpath = imgpath.replace('images', 'labels').replace(
                'JPEGImages',
                'labels').replace('.jpg', '.txt').replace('.png', '.txt')
            label = torch.zeros(800 * 5)
            try:
                tmp = torch.from_numpy(
                    read_truths_args(labpath, 5.0 / img.width,
                                     self.shape).astype('float32'))
            except Exception:
                tmp = torch.zeros(1, 5)
            tmp = tmp.view(-1)
            tsz = tmp.numel()
            if tsz > 800 * 5:
                label = tmp[0:800 * 5]
            elif tsz > 0:
                label[0:tsz] = tmp

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers
        return (img, label)
Example #8
    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        imgpath = self.lines[index].rstrip()

        if self.train:
            img, label = load_data_detection(imgpath, self.shape, self.jitter,
                                             self.hue, self.saturation,
                                             self.exposure)
            label = torch.from_numpy(label)
        else:
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img, org_w, org_h = letter_image(
                    img, self.shape[0], self.shape[1]), img.width, img.height
            labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels')\
                .replace('.jpg', '.txt').replace('.png', '.txt')
            # one image at most has 50 bounding boxes, we need to have a fixed size
            label = torch.zeros(50 * 5)
            try:
                tmp = torch.from_numpy(
                    read_truths_args(labpath,
                                     8.0 / img.width).astype('float32'))
            except Exception:
                tmp = torch.zeros(1, 5)
            tmp = tmp.view(-1)
            tsz = tmp.numel()  # element number
            if tsz > 50 * 5:
                label = tmp[0:50 * 5]
            elif tsz > 0:
                label[0:tsz] = tmp

        if self.transform:
            img = self.transform(img)

        if self.train:
            return img, label
        else:
            # we need to rescale the image back to its original size to evaluate
            # performance, so we need to record the original size
            return img, label, org_w, org_h
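Examples #3, #8, #16, and #17 resize with `letterbox_image` rather than a plain `resize`: the image is scaled to fit the network input while keeping its aspect ratio, and the remainder is padded. A PIL-based sketch of the idea; the padding color and return signature are assumptions (the helpers in the examples also return pad offsets or take a `return_dxdy` flag):

from PIL import Image

def letterbox_image(img, net_w, net_h):
    # Scale to fit inside (net_w, net_h) without distorting the aspect ratio,
    # then center the result on a gray canvas.
    im_w, im_h = img.size
    scale = min(net_w / float(im_w), net_h / float(im_h))
    new_w, new_h = int(im_w * scale), int(im_h * scale)
    resized = img.resize((new_w, new_h), Image.BICUBIC)
    canvas = Image.new('RGB', (net_w, net_h), (128, 128, 128))
    canvas.paste(resized, ((net_w - new_w) // 2, (net_h - new_h) // 2))
    return canvas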
Example #9
    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        imgpath = self.parent_dir + self.lines[index].rstrip()

        img = Image.open(imgpath).convert('RGB')
        if self.shape:
            img = img.resize(self.shape)
        label = torch.zeros(50 * 21)
        k = 0
        for obj in self.objs:
            labpath = imgpath.replace('benchvise', obj).replace(
                'images',
                'labels_occlusion').replace('JPEGImages',
                                            'labels_occlusion').replace(
                                                '.jpg', '.txt').replace(
                                                    '.png', '.txt')
            if os.path.getsize(labpath):
                ow, oh = img.size
                tmp = torch.from_numpy(read_truths_args(labpath))
                tmp = tmp.view(-1)
                tsz = tmp.numel()
                #print(tmp[0])
                if tsz > 50 * 21:
                    label = tmp[0:50 * 21]
                elif tsz > 0:
                    label[k:k + tsz] = tmp
                k = k + tsz

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers
        return (img, label)
Example #10
    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        imgpath = self.lines[index].rstrip()

        if self.train and index % 64 == 0:
            if self.seen < 4000 * 64:
                width = 13 * 32
                self.shape = (width, width)
            elif self.seen < 8000 * 64:
                width = (random.randint(0, 3) + 13) * 32
                self.shape = (width, width)
            elif self.seen < 12000 * 64:
                width = (random.randint(0, 5) + 12) * 32
                self.shape = (width, width)
            elif self.seen < 16000 * 64:
                width = (random.randint(0, 7) + 11) * 32
                self.shape = (width, width)
            else:  # self.seen < 20000*64:
                width = (random.randint(0, 9) + 10) * 32
                self.shape = (width, width)

        if self.train:
            jitter = 0.2
            hue = 0.1
            saturation = 1.5
            exposure = 1.5

            img, label = load_data_detection(imgpath, self.shape, jitter, hue,
                                             saturation, exposure)
            label = torch.from_numpy(label)
        else:
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img = img.resize(self.shape)

            #labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
            # Note: the label path is fixed while imgpath varies, so a simple string replace is not enough; the label path can be passed in via the config file
            file = imgpath.split('/')[-1]
            labpath = self.label_path + file.replace('.jpg', '.txt').replace(
                '.png', '.txt').replace('.JPEG', '.txt')
            #print('-----------------',labpath)
            label = torch.zeros(50 * 5)
            #if os.path.getsize(labpath):
            #tmp = torch.from_numpy(np.loadtxt(labpath))
            #try:
            tmp = torch.from_numpy(
                read_truths_args(labpath, 8.0 / img.width).astype('float32'))
            #except Exception:
            #    tmp = torch.zeros(1,5)
            #tmp = torch.from_numpy(read_truths(labpath))
            tmp = tmp.view(-1)
            tsz = tmp.numel()
            #print('labpath = %s , tsz = %d' % (labpath, tsz))
            if tsz > 50 * 5:
                label = tmp[0:50 * 5]
            elif tsz > 0:
                label[0:tsz] = tmp

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers
        #print('--------------------------labpath',labpath,label)
        return (img, label)
Example #11
    def __getitem__(self, index):

        # Ensure the index is smaller than the number of samples in the dataset, otherwise raise an error
        assert index < len(self), 'index range error'

        # Get the image path
        imgpath = self.lines[index].rstrip()

        # Decide which size to resize the image to, depending on the iteration
        if self.train and index % self.batch_size == 0:
            if self.seen < 400 * self.batch_size:
                width = 13 * self.cell_size
                self.shape = (width, width)
            elif self.seen < 800 * self.batch_size:
                width = (random.randint(0, 7) + 13) * self.cell_size
                self.shape = (width, width)
            elif self.seen < 1200 * self.batch_size:
                width = (random.randint(0, 9) + 12) * self.cell_size
                self.shape = (width, width)
            elif self.seen < 1600 * self.batch_size:
                width = (random.randint(0, 11) + 11) * self.cell_size
                self.shape = (width, width)
            elif self.seen < 2000 * self.batch_size:
                width = (random.randint(0, 13) + 10) * self.cell_size
                self.shape = (width, width)
            elif self.seen < 2400 * self.batch_size:
                width = (random.randint(0, 15) + 9) * self.cell_size
                self.shape = (width, width)
            elif self.seen < 3000 * self.batch_size:
                width = (random.randint(0, 17) + 8) * self.cell_size
                self.shape = (width, width)
            else:
                width = (random.randint(0, 19) + 7) * self.cell_size
                self.shape = (width, width)

        if self.train:
            # If you are going to train, decide on how much data augmentation you are going to apply
            # jitter = 0.2
            # hue = 0.1
            # saturation = 1.5
            # exposure = 1.5

            jitter = 0.2
            hue = 0.1
            saturation = 1.5
            exposure = 1

            # Get background image path
            random_bg_index = random.randint(0, len(self.bg_file_names) - 1)
            bgpath = self.bg_file_names[random_bg_index]

            # Get the data augmented image and their corresponding labels
            img, label = load_data_detection(imgpath, self.shape, jitter, hue,
                                             saturation, exposure, bgpath)

            # # Save to see the dataset
            # import os
            # import time
            # def mkdir(path):
            #     path = path.strip()
            #     path = path.rstrip("/")
            #     isExists = os.path.exists(path)
            #     if not isExists:
            #         os.makedirs(path)
            #         print(path + ' created successfully')
            #         return True
            #     else:
            #         print(path + ' already exist')
            #         return False
            # mkdir('debug/4')
            # imgpath = 'debug/4/{}.png'.format(time.time())
            # img.save(imgpath)
            # print("{} saved!".format(imgpath))

            # Convert the labels to PyTorch variables
            label = torch.from_numpy(label)

        else:
            # Get the validation image, resize it to the network input size
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img = img.resize(self.shape)

            # Read the validation labels, allowing up to 50 ground-truth objects per image
            labpath = imgpath.replace('images', 'labels').replace(
                'JPEGImages',
                'labels').replace('.jpg', '.txt').replace('.png', '.txt')
            label = torch.zeros(50 * 21)
            if os.path.getsize(labpath):
                ow, oh = img.size
                tmp = torch.from_numpy(read_truths_args(labpath))
                tmp = tmp.view(-1)
                tsz = tmp.numel()
                if tsz > 50 * 21:
                    label = tmp[0:50 * 21]
                elif tsz > 0:
                    label[0:tsz] = tmp

        # Transform the image data to PyTorch tensors
        if self.transform is not None:
            img = self.transform(img)

        # If there is any PyTorch-specific transformation, transform the label data
        if self.target_transform is not None:
            label = self.target_transform(label)

        # Increase the number of seen examples
        self.seen = self.seen + self.num_workers

        # Return the retrieved image and its corresponding label
        return (img, label)
Example #12
    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        imgpath = self.lines[index].rstrip()
        ''' Fix the width to be 13*32=416 and do not randomize
        '''
        width = 13 * 32  # note: the multi-scale block below still runs and can override self.shape

        if self.train and index % 64 == 0:
            if self.seen < 4000 * 64:
                width = 13 * 32
                self.shape = (width, width)
            elif self.seen < 8000 * 64:
                width = (random.randint(0, 3) + 13) * 32
                self.shape = (width, width)
            elif self.seen < 12000 * 64:
                width = (random.randint(0, 5) + 12) * 32
                self.shape = (width, width)
            elif self.seen < 16000 * 64:
                width = (random.randint(0, 7) + 11) * 32
                self.shape = (width, width)
            else:  # self.seen < 20000*64:
                width = (random.randint(0, 9) + 10) * 32
                self.shape = (width, width)

        if self.train:
            jitter = 0.2
            hue = 0.1
            saturation = 1.5
            exposure = 1.5

            img, label = load_data_detection(imgpath, self.shape, jitter, hue,
                                             saturation, exposure)
            label = torch.from_numpy(label)
        else:
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img = img.resize(self.shape)

            totensor_transform = transforms.ToTensor()
            img_tensor = totensor_transform(img)

            noise = np.load(self.noise_path)
            img_tensor[:, 5:5 + patchSize, 5:5 + patchSize] = torch.from_numpy(
                noise[:, 5:5 + patchSize, 5:5 + patchSize])
            # For KITTI
            # img_tensor[:, 5+50:5+50+100, 158:158+100] = torch.from_numpy(noise[:, 5+50:5+50+100, 158:158+100])
            # img_tensor = img_tensor + torch.from_numpy(noise)
            img_tensor = torch.clamp(img_tensor, 0, 1)

            labpath = imgpath.replace('images', 'labels').replace(
                'JPEGImages',
                'labels').replace('.jpg', '.txt').replace('.png', '.txt')
            # print(labpath)
            #labpath = imgpath.replace('images', 'labels').replace('train', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
            label = torch.zeros(50 * 5)
            #if os.path.getsize(labpath):
            #tmp = torch.from_numpy(np.loadtxt(labpath))
            try:
                tmp = torch.from_numpy(
                    read_truths_args(labpath,
                                     8.0 / img.width).astype('float32'))
            except Exception:
                tmp = torch.zeros(1, 5)
            #tmp = torch.from_numpy(read_truths(labpath))
            tmp = tmp.view(-1)
            tsz = tmp.numel()
            #print('labpath = %s , tsz = %d' % (labpath, tsz))
            if tsz > 50 * 5:
                label = tmp[0:50 * 5]
            elif tsz > 0:
                label[0:tsz] = tmp

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers
        # Note: img_tensor is only assigned in the evaluation branch above, so
        # this snippet as written only supports train=False.
        return (img_tensor, label, imgpath)
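Example #12 pastes a fixed noise patch over the top-left corner of the evaluation image; note that `patchSize` must already exist at module scope for the snippet to run. Isolating the overlay as a small helper makes the intent clearer; this is a sketch, not a helper from the source:

import torch

def apply_noise_patch(img_tensor, noise, patch_size, offset=5):
    # Overwrite a patch_size x patch_size square of the CHW image tensor with
    # the matching region of a precomputed noise array, then clamp to [0, 1].
    region = slice(offset, offset + patch_size)
    img_tensor[:, region, region] = torch.from_numpy(noise[:, region, region])
    return torch.clamp(img_tensor, 0, 1)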
Example #13
    def __init__(self,
                 root,
                 shape=None,
                 shuffle=True,
                 transform=None,
                 target_transform=None,
                 train=False,
                 seen=0,
                 batch_size=64,
                 num_workers=4,
                 cell_size=32,
                 bg_file_names=None,
                 num_keypoints=9,
                 max_num_gt=50,
                 cam_params=None,
                 corners3D=None):

        # root             : list of training or test images
        # shape            : shape of the image input to the network
        # shuffle          : whether to shuffle or not
        # transform        : any pytorch-specific transformation to the input image
        # target_transform : any pytorch-specific transformation to the target output
        # train            : whether it is training data or test data
        # seen             : the number of visited examples (iteration of the batch x batch size) # TODO: check if this is correctly assigned
        # batch_size       : how many examples there are in the batch
        # num_workers      : number of data-loading workers (used here to advance the `seen` counter)
        # bg_file_names    : the filenames for images from which you assign random backgrounds

        # Read the list of dataset images
        with open(root, 'r') as file:
            self.lines = file.readlines()

        if cam_params is not None:
            # reject any gt's that are not in FOV!! (or are too far away)
            keep_list = []
            box_gt = []

            def truths_length(truths, max_num_gt=50):
                for i in range(max_num_gt):
                    if truths[i][1] == 0:
                        return i
                return max_num_gt  # every slot is filled

            my_camera = camera(
                *cam_params
            )  # cam_params = (K, dist_coefs, im_width, im_height, tf_cam_ego)
            for l in self.lines:
                # Get the image path
                num_img = len(self.lines)
                imgpath = l.rstrip()
                labpath = imgpath.replace('images', 'labels').replace(
                    'JPEGImages',
                    'labels').replace('.jpg', '.txt').replace('.png', '.txt')
                num_labels = 2 * num_keypoints + 3  # +2 for ground-truth of width/height , +1 for class label
                label = torch.zeros(max_num_gt * num_labels)
                if os.path.getsize(labpath):
                    ow = my_camera.im_w
                    oh = my_camera.im_h
                    tmp = torch.from_numpy(read_truths_args(labpath))
                    tmp = tmp.view(-1)
                    tsz = tmp.numel()
                    if tsz > max_num_gt * num_labels:
                        label = tmp[0:max_num_gt * num_labels]
                    elif tsz > 0:
                        label[0:tsz] = tmp
                label = label.cuda()
                truths = label.view(-1, num_labels)
                num_gts = truths_length(truths)
                for k in range(num_gts):
                    box_gt = list()
                    for j in range(1, 2 * num_keypoints + 1):
                        box_gt.append(truths[k][j])
                    box_gt.extend([1.0, 1.0])
                    box_gt.append(truths[k][0])
                corners2D_gt = np.array(np.reshape(box_gt[:18], [-1, 2]),
                                        dtype='float32')
                corners2D_gt[:, 0] = corners2D_gt[:, 0] * my_camera.im_w
                corners2D_gt[:, 1] = corners2D_gt[:, 1] * my_camera.im_h

                R_gt, t_gt = pnp(
                    np.array(np.transpose(
                        np.concatenate((np.zeros((3, 1)), corners3D[:3, :]),
                                       axis=1)),
                             dtype='float32'), corners2D_gt,
                    np.array(my_camera.K, dtype='float32'))
                MAX_DIST_TO_KEEP = 4
                if np.any(corners2D_gt < 0) or np.any(
                        corners2D_gt[:, 0] > my_camera.im_w) or np.any(
                            corners2D_gt[:, 1] > my_camera.im_h
                        ) or la.norm(t_gt) > MAX_DIST_TO_KEEP:
                    keep_list.append(False)
                else:
                    keep_list.append(True)

            self.lines = [
                l for (l, my_bool) in zip(self.lines, keep_list) if my_bool
            ]
            if len(self.lines) == 0:
                raise RuntimeError(
                    "ALL IMAGES SCREENED OUT (out of original {} images)".
                    format(num_img))
            print(
                'Keeping {} / {} images (throwing out images with no drone in FOV)'
                .format(len(self.lines), num_img))

        # Shuffle
        if shuffle:
            random.shuffle(self.lines)

        # Initialize variables
        self.nSamples = len(self.lines)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.shape = shape
        self.seen = seen
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.bg_file_names = bg_file_names
        self.cell_size = cell_size
        self.nbatches = self.nSamples // self.batch_size
        self.num_keypoints = num_keypoints
        self.max_num_gt = max_num_gt  # maximum number of ground-truth labels an image can have
Example #14
    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        imgpath = self.lines[index].rstrip()

        if self.train and index % 32 == 0:
            if self.seen < 400 * 32:
                width = 13 * 32
                self.shape = (width, width)
            elif self.seen < 800 * 32:
                width = (random.randint(0, 7) + 13) * 32
                self.shape = (width, width)
            elif self.seen < 1200 * 32:
                width = (random.randint(0, 9) + 12) * 32
                self.shape = (width, width)
            elif self.seen < 1600 * 32:
                width = (random.randint(0, 11) + 11) * 32
                self.shape = (width, width)
            elif self.seen < 2000 * 32:
                width = (random.randint(0, 13) + 10) * 32
                self.shape = (width, width)
            elif self.seen < 2400 * 32:
                width = (random.randint(0, 15) + 9) * 32
                self.shape = (width, width)
            elif self.seen < 3000 * 32:
                width = (random.randint(0, 17) + 8) * 32
                self.shape = (width, width)
            else:  # self.seen >= 3000*32
                width = (random.randint(0, 19) + 7) * 32
                self.shape = (width, width)

        if self.train:
            jitter = 0.2
            hue = 0.1
            saturation = 1.5 
            exposure = 1.5

            # Get background image path
            random_bg_index = random.randint(0, len(self.bg_file_names) - 1)
            bgpath = self.bg_file_names[random_bg_index]

            img, label = load_data_detection(imgpath, self.shape, jitter, hue, saturation, exposure, bgpath)
            label = torch.from_numpy(label)
        else:
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img = img.resize(self.shape)
    
            labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
            label = torch.zeros(50*21)
            if os.path.getsize(labpath):
                ow, oh = img.size
                tmp = torch.from_numpy(read_truths_args(labpath, 8.0/ow))
                tmp = tmp.view(-1)
                tsz = tmp.numel()
                if tsz > 50*21:
                    label = tmp[0:50*21]
                elif tsz > 0:
                    label[0:tsz] = tmp

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers
        return (img, label)
Example #15
    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        imgpath = self.lines[index].rstrip()

        # Resize images depending on how many images and batches have been seen
        if self.train and index % 64 == 0:
            if self.seen < 4000 * 64:
                width = 13 * 32
                self.shape = (width, width)
            elif self.seen < 8000 * 64:
                width = (random.randint(0, 3) + 13) * 32
                self.shape = (width, width)
            elif self.seen < 12000 * 64:
                width = (random.randint(0, 5) + 12) * 32
                self.shape = (width, width)
            elif self.seen < 16000 * 64:
                width = (random.randint(0, 7) + 11) * 32
                self.shape = (width, width)
            else:  # self.seen < 20000*64:
                width = (random.randint(0, 9) + 10) * 32
                self.shape = (width, width)

        if self.train:
            # Distort images for augmentation only during training
            # Image distortion parameters
            jitter = 0.2
            hue = 0.1
            saturation = 1.5
            exposure = 1.5

            # read images and distort them and return distorted input and target values
            img, label = load_data_detection(imgpath, self.shape, jitter, hue,
                                             saturation, exposure)
            label = torch.from_numpy(
                label)  # convert numpy labels to torch tensor labels
        else:
            # Read image, convert to RGB and possibly resize
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img = img.resize(self.shape)

            # Build path replacing strings
            labpath = imgpath.replace('images', 'labels').replace(
                'JPEGImages',
                'labels').replace('.jpg', '.txt').replace('.png', '.txt')
            label = torch.zeros(50 * 5)  # Empty labels
            #if os.path.getsize(labpath):
            #tmp = torch.from_numpy(np.loadtxt(labpath))
            try:
                tmp = torch.from_numpy(
                    read_truths_args(labpath,
                                     8.0 / img.width).astype('float32'))
            except Exception:
                tmp = torch.zeros(1, 5)
            #tmp = torch.from_numpy(read_truths(labpath))
            tmp = tmp.view(-1)
            tsz = tmp.numel()  # number of elements in the temporary tensor
            #print('labpath = %s , tsz = %d' % (labpath, tsz))
            if tsz > 50 * 5:
                label = tmp[0:50 * 5]
            elif tsz > 0:
                label[0:tsz] = tmp

        # Possibly do some transformations
        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers
        return (img, label)
Example #16
    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        line = self.lines[index].replace('\n', '').split(' ')
        #imgpath = self.lines[index].rstrip()
        imgpath = os.path.join(self.imgRoot, line[0])
        #print(imgpath)
        # Parse the remaining tokens into an (N, 5) array and roll the class id
        # from the last column to the first, giving (cls, cx, cy, w, h) rows.
        lab_valueset = np.roll(
            np.array(eval(','.join(line[1:])), dtype=np.float64).reshape(-1, 5),
            1, axis=1)
        #print(lab_valueset)
        if self.train:
            #if self.seen % (self.batch_size * 10) == 0: # in paper, every 10 batches, but we did every 64 images
            #   self.shape = self.get_different_scale()
            img, label = load_data_detection(imgpath, self.shape, self.crop,
                                             self.jitter, self.hue,
                                             self.saturation, self.exposure,
                                             lab_valueset)
            #print('cp2,label:    ',label)
            label = torch.from_numpy(label)

        else:
            img = Image.open(imgpath).convert('RGB')
            #print(imgpath)
            original_img_size = img.size
            #print(f'img.height: {img.height}')
            #print(f'img.width: {img.width}')
            #print(f'lab_valueset: {lab_valueset}')
            if self.shape:
                img, org_w, org_h = letterbox_image(
                    img, self.shape[0], self.shape[1]), img.width, img.height

            #labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
            label = torch.zeros(50 * 5)
            #if os.path.getsize(labpath):
            #tmp = torch.from_numpy(np.loadtxt(labpath))
            try:
                #tmp = torch.from_numpy(read_truths_args(labpath, 8.0/img.width).astype('float32'))
                tmp = torch.from_numpy(
                    read_truths_args(lab_valueset, original_img_size,
                                     8.0 / img.width).astype('float32'))
                #print(tmp)
            except Exception:
                tmp = torch.zeros(1, 5)
            #tmp = torch.from_numpy(read_truths(labpath))
            tmp = tmp.view(-1)
            tsz = tmp.numel()
            #print('labpath = %s , tsz = %d' % (labpath, tsz))
            if tsz > 50 * 5:
                label = tmp[0:50 * 5]
            elif tsz > 0:
                label[0:tsz] = tmp

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers
        #print('cp1',(img, label))
        #print('cp1',(img.shape, label.shape))
        if self.train:
            return (img, label)
        else:
            return (img, label, org_w, org_h)
Example #17
    def __getitem__(self, index):
        # print('get item')
        assert index < len(self), 'index range error'
        imgpath = self.lines[index].rstrip()

        img_id = os.path.basename(imgpath).split('.')[0]

        if self.train:
            # print(index)
            if self.seen % (self.batch_size * 100) == 0:  # in the paper, every 10 batches; here, every 100 batches
                self.shape = self.get_different_scale_my()
                # self.shape = self.get_different_scale()
                # print('Image size: ', self.shape)
                # self.shape = self.get_different_scale()
            img, label = load_data_detection(imgpath, self.shape, self.crop,
                                             self.jitter, self.hue,
                                             self.saturation, self.exposure)
            label = torch.from_numpy(label)
        else:
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img, org_w, org_h = letterbox_image(
                    img, self.shape[0], self.shape[1]), img.width, img.height

            # labpath = imgpath.replace('images', 'labels').replace('images', 'Annotations').replace('.jpg', '.txt').replace('.png','.txt')
            labpath = imgpath.replace('images', 'labels').replace(
                '.jpg', '.txt').replace('.jpeg', '.txt').replace(
                    '.png', '.txt').replace('.tif', '.txt')
            label = torch.zeros(50 * 5)
            #if os.path.getsize(labpath):
            #tmp = torch.from_numpy(np.loadtxt(labpath))
            try:
                tmp = torch.from_numpy(
                    read_truths_args(labpath,
                                     8.0 / img.width).astype('float32'))
            except Exception:
                tmp = torch.zeros(1, 5)
            #tmp = torch.from_numpy(read_truths(labpath))
            tmp = tmp.view(-1)
            tsz = tmp.numel()
            #print('labpath = %s , tsz = %d' % (labpath, tsz))
            if tsz > 50 * 5:
                label = tmp[0:50 * 5]
            elif tsz > 0:
                label[0:tsz] = tmp

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers

        if self.train:
            if self.condition:
                #### label whether the image is daytime or nighttime on the KAIST dataset
                set_label = 0 if int(img_id[4]) < 3 else 1
                return (img, (label, set_label))
            else:
                # print('end function get item')
                return (img, label)
        else:
            return (img, label, org_w, org_h)
Example #18
    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        imgpath = self.lines[index].rstrip()

        # if self.train and index % 64 == 0: # index is the line index within ImageSets/Main/trainval.txt
        #     # print("change input size at seen = %d"%self.seen)
        #     if self.seen < 18000*64: # VOC2007 + VOC2012 trainval is about 16K images; 4000*64 is roughly 5 epochs of VOC trainval
        #        width = 19*32 # 800
        #        self.shape = (width, width)
        #     elif self.seen < 23000*64:
        #        width = (random.randint(0,3) + 19)*32
        #        self.shape = (width, width)
        #     elif self.seen < 28000*64:
        #        width = (random.randint(0,5) + 18)*32
        #        self.shape = (width, width)
        #     elif self.seen < 33000*64:
        #        width = (random.randint(0,7) + 17)*32
        #        self.shape = (width, width)
        #     else: # self.seen < 20000*64:
        #        width = (random.randint(0,9) + 16)*32 # 512~800
        #        self.shape = (width, width)
        #        print(self.shape)

        if self.train:
            jitter = 0.2
            hue = 0.1
            saturation = 1.5
            exposure = 1.5

            img, label = load_data_detection(imgpath, self.shape, jitter, hue,
                                             saturation, exposure)
            label = torch.from_numpy(label)
        else:
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img = img.resize(self.shape)

            labpath = imgpath.replace('images', 'labels').replace(
                'JPEGImages',
                'labels').replace('.jpg', '.txt').replace('.png', '.txt')
            label = torch.zeros(50 * 5)
            #if os.path.getsize(labpath):
            #tmp = torch.from_numpy(np.loadtxt(labpath))
            try:
                tmp = torch.from_numpy(
                    read_truths_args(labpath,
                                     8.0 / img.width).astype('float32'))
            except Exception:
                tmp = torch.zeros(1, 5)
            #tmp = torch.from_numpy(read_truths(labpath))
            tmp = tmp.view(-1)
            tsz = tmp.numel()
            #print('labpath = %s , tsz = %d' % (labpath, tsz))
            if tsz > 50 * 5:
                label = tmp[0:50 * 5]
            elif tsz > 0:
                label[0:tsz] = tmp

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers
        return (imgpath, img, label)
Example #19
    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        imgpath = self.lines[index].rstrip()

        out_w, out_h = self.shape

        if self.train:
            jitter = 0.3
            hue = 0.1
            saturation = 1.5
            exposure = 1.5

            image = cv2.imread(imgpath)
            assert image is not None
            origin_height, origin_width = image.shape[:2]

            dw = origin_width * jitter
            dh = origin_height * jitter

            pleft = int(random.uniform(-dw, dw))
            pright = int(random.uniform(-dw, dw))
            ptop = int(random.uniform(-dh, dh))
            pbot = int(random.uniform(-dh, dh))

            swidth = origin_width - pleft - pright
            sheight = origin_height - ptop - pbot

            pleft2, pright2, ptop2, pbot2 = map(abs, [pleft, pright, ptop, pbot])

            image2 = cv2.copyMakeBorder(image, ptop2, pbot2, pleft2, pright2, cv2.BORDER_REPLICATE)
            croped = image2[ptop2+ptop:ptop2+ptop+sheight, pleft2+pleft:pleft2+pleft+swidth, :]

            sized, new_w, new_h, dx3, dy3 = letterbox_image(croped, out_w, out_h, return_dxdy=True)

            img = random_distort_image(sized, hue, saturation, exposure)

            flip = (random.random() < 0.5)
            if flip:
                img = cv2.flip(img, 1)

            image = img

            labpath = imgpath.replace('images', 'labels') \
                .replace('JPEGImages', 'labels') \
                .replace('.jpg', '.txt') \
                .replace('.png', '.txt')

            label = np.loadtxt(labpath)
            if label.size == 0:  # np.loadtxt returns an empty array for an empty file, never None
                label = np.full((5,), -1, dtype=np.float32)
            else:
                label2 = np.full((label.size//5, 5), -1, np.float32)
                bs = np.reshape(label, (-1, 5))
                cc = 0
                for i in range(bs.shape[0]):
                    x1 = bs[i][1] - bs[i][3] / 2
                    y1 = bs[i][2] - bs[i][4] / 2
                    x2 = bs[i][1] + bs[i][3] / 2
                    y2 = bs[i][2] + bs[i][4] / 2

                    x1 = min(swidth, max(0, x1 * origin_width - pleft))
                    y1 = min(sheight, max(0, y1 * origin_height - ptop))
                    x2 = min(swidth, max(0, x2 * origin_width - pleft))
                    y2 = min(sheight, max(0, y2 * origin_height - ptop))

                    x1 = (x1 / swidth * new_w + dx3) / out_w
                    y1 = (y1 / sheight * new_h + dy3) / out_h
                    x2 = (x2 / swidth * new_w + dx3) / out_w
                    y2 = (y2 / sheight * new_h + dy3) / out_h

                    bs[i][1] = (x1 + x2) / 2
                    bs[i][2] = (y1 + y2) / 2
                    bs[i][3] = (x2 - x1)
                    bs[i][4] = (y2 - y1)

                    if flip:
                        bs[i][1] = 0.999 - bs[i][1]

                    if bs[i][3] < 0.001 or bs[i][4] < 0.001:
                        continue
                    label2[cc] = bs[i]
                    cc += 1
                    if cc >= 50:
                        break
                label = label2[:cc].flatten()
        else:
            image = cv2.imread(imgpath)
            assert image is not None
            sized, new_w, new_h, dx, dy = letterbox_image(image, out_w, out_h, return_dxdy=True)

            labpath = imgpath.replace('images', 'labels') \
                .replace('JPEGImages', 'labels') \
                .replace('.jpg', '.txt').replace('.png', '.txt')

            tmp = read_truths_args(labpath, 8.0 / image.shape[1]).astype(np.float32)
            tmp[:, 1:] = (tmp[:, 1:] * np.array([new_w, new_h, new_w, new_h]) + np.array([dx, dy, 0, 0])) / np.array(
                [out_w, out_h, out_w, out_h])

            label = tmp.flatten()
            image = sized

        return dict(image=image, label=label)
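Because Example #19 returns a dict whose label is a variable-length flattened array, default DataLoader collation would fail on mismatched label sizes. A custom `collate_fn` can pad each label out to the same fixed 50*5 slots the other examples use; a sketch (names are illustrative):

import numpy as np
import torch

def collate_detection(batch):
    # Convert HWC uint8 images to float CHW tensors and stack them; pad each
    # flattened label out to 50*5 slots so the batch forms a single tensor.
    images = torch.stack([
        torch.from_numpy(item['image'].transpose(2, 0, 1).copy()).float() / 255.0
        for item in batch
    ])
    labels = torch.zeros(len(batch), 50 * 5)
    for i, item in enumerate(batch):
        lab = torch.from_numpy(np.asarray(item['label'], dtype=np.float32)).view(-1)
        n = min(lab.numel(), 50 * 5)
        labels[i, :n] = lab[:n]
    return images, labels

# usage: DataLoader(dataset, batch_size=8, collate_fn=collate_detection)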
Example #20
    def __getitem__(self, index):

        # Ensure the index is smaller than the number of samples in the dataset, otherwise raise an error
        assert index < len(self), 'index range error'

        # Get the image path
        imgindex = self.lines[index].rstrip()
        # print('imgindex', imgindex)
        imgpath = os.path.join(self.dataDir, 'rgb',
                               str(imgindex) + self.rgbfileType)
        if not self.train:
            print('imgpath', imgpath, end='\r')

        # Decide which size to resize the image to, depending on the epoch (10, 20, etc.)
        if self.train and index % self.batch_size == 0:
            if self.seen < 10 * self.nbatches * self.batch_size:
                width = 13 * self.cell_size
                self.shape = (width, width)
            elif self.seen < 20 * self.nbatches * self.batch_size:
                width = (random.randint(0, 7) + 13) * self.cell_size
                self.shape = (width, width)
            elif self.seen < 30 * self.nbatches * self.batch_size:
                width = (random.randint(0, 9) + 12) * self.cell_size
                self.shape = (width, width)
            elif self.seen < 40 * self.nbatches * self.batch_size:
                width = (random.randint(0, 11) + 11) * self.cell_size
                self.shape = (width, width)
            elif self.seen < 50 * self.nbatches * self.batch_size:
                width = (random.randint(0, 13) + 10) * self.cell_size
                self.shape = (width, width)
            elif self.seen < 60 * self.nbatches * self.batch_size:
                width = (random.randint(0, 15) + 9) * self.cell_size
                self.shape = (width, width)
            elif self.seen < 70 * self.nbatches * self.batch_size:
                width = (random.randint(0, 17) + 8) * self.cell_size
                self.shape = (width, width)
            else:
                width = (random.randint(0, 19) + 7) * self.cell_size
                self.shape = (width, width)

        if self.train:
            # Decide on how much data augmentation you are going to apply
            jitter = 0.2
            hue = 0.1
            saturation = 1.5
            exposure = 1.5

            # Get background image path
            random_bg_index = random.randint(0, len(self.bg_file_names) - 1)
            bgpath = self.bg_file_names[random_bg_index]

            # Get the data augmented image and their corresponding labels
            img, label = load_data_detection(imgpath, self.shape, jitter, hue,
                                             saturation, exposure, bgpath,
                                             self.num_keypoints,
                                             self.max_num_gt)

            # Convert the labels to PyTorch variables
            label = torch.from_numpy(label)

        else:
            # Get the validation image, resize it to the network input size
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img = img.resize(self.shape)

            # Read the validation labels, allowing up to 50 ground-truth objects per image
            labpath = imgpath.replace('rgb', 'labels').replace(
                self.rgbfileType, '.txt')
            # print('testing', labpath)

            num_labels = 2 * self.num_keypoints + 3  # +2 for ground-truth of width/height , +1 for class label
            label = torch.zeros(self.max_num_gt * num_labels)
            if os.path.getsize(labpath):
                ow, oh = img.size
                tmp = torch.from_numpy(read_truths_args(labpath))
                tmp = tmp.view(-1)
                tsz = tmp.numel()
                if tsz > self.max_num_gt * num_labels:
                    label = tmp[0:self.max_num_gt * num_labels]
                elif tsz > 0:
                    label[0:tsz] = tmp

        # Transform the image data to PyTorch tensors
        if self.transform is not None:
            img = self.transform(img)

        # If there is any PyTorch-specific transformation, transform the label data
        if self.target_transform is not None:
            label = self.target_transform(label)

        # Increase the number of seen examples
        self.seen = self.seen + self.num_workers

        # Return the retrieved image and its corresponding label
        return (img, label)
Example #21
    def __getitem__(self, index):
        assert index < len(self), 'index range error'
        imgpath = self.lines[index].rstrip()
        ''' Fix the width to be 13*32=416 and do not randomize
        '''
        width = 13 * 32  # note: the multi-scale block below still runs and can override self.shape

        if self.train and index % 64 == 0:
            if self.seen < 4000 * 64:
                width = 13 * 32
                self.shape = (width, width)
            elif self.seen < 8000 * 64:
                width = (random.randint(0, 3) + 13) * 32
                self.shape = (width, width)
            elif self.seen < 12000 * 64:
                width = (random.randint(0, 5) + 12) * 32
                self.shape = (width, width)
            elif self.seen < 16000 * 64:
                width = (random.randint(0, 7) + 11) * 32
                self.shape = (width, width)
            else:  # self.seen < 20000*64:
                width = (random.randint(0, 9) + 10) * 32
                self.shape = (width, width)

        if self.train:
            jitter = 0.2
            hue = 0.1
            saturation = 1.5
            exposure = 1.5

            img, label = load_data_detection(imgpath, self.shape, jitter, hue,
                                             saturation, exposure)
            label = torch.from_numpy(label)
        else:
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img = img.resize(self.shape)

            labpath = imgpath.replace('images', 'labels').replace(
                'JPEGImages',
                'labels').replace('.jpg', '.txt').replace('.png', '.txt')
            # # for KITTI
            # labpath = imgpath.replace('images', 'labels').replace('PNGImages_cropped', 'labels_cropped_car_person').replace('.jpg', '.txt').replace('.png','.txt')
            #labpath = imgpath.replace('images', 'labels').replace('train', 'labels').replace('.jpg', '.txt').replace('.png','.txt')
            label = torch.zeros(50 * 5)
            try:
                tmp = torch.from_numpy(
                    read_truths_args(labpath,
                                     8.0 / img.width).astype('float32'))
            except Exception:
                tmp = torch.zeros(1, 5)
            #tmp = torch.from_numpy(read_truths(labpath))

            # # for KITTI
            # if tmp.size() == 0:
            #     tmp = torch.zeros(1,5)
            tmp = tmp.view(-1)
            tsz = tmp.numel()
            # print('labpath = %s , tsz = %d' % (labpath, tsz))
            if tsz > 50 * 5:
                label = tmp[0:50 * 5]
            elif tsz > 0:
                label[0:tsz] = tmp

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers
        return (img, label, imgpath)
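The multi-scale block above re-picks the network input size once per 64-sample batch, widening the range of candidate resolutions as self.seen grows. The same schedule as a standalone sketch, with the constants copied from the code above (the function name is illustrative):

    import random

    def multiscale_width(seen, batch_size=64, stride=32):
        """Pick a square input width from a range that grows with training progress."""
        if seen < 4000 * batch_size:
            return 13 * stride                           # fixed 416 warm-up
        elif seen < 8000 * batch_size:
            return (random.randint(0, 3) + 13) * stride  # 416..512
        elif seen < 12000 * batch_size:
            return (random.randint(0, 5) + 12) * stride  # 384..544
        elif seen < 16000 * batch_size:
            return (random.randint(0, 7) + 11) * stride  # 352..576
        else:
            return (random.randint(0, 9) + 10) * stride  # 320..608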
Example #22
    def __getitem__(self, index):
        assert index <= len(self), 'index range error'
        imgpath = self.lines[index].rstrip()

        if self.train and index % 64 == 0:
            if self.seen < 4000 * 64:
                width = 13 * 32
                self.shape = (width, width)
            elif self.seen < 8000 * 64:
                width = (random.randint(0, 3) + 13) * 32
                self.shape = (width, width)
            elif self.seen < 12000 * 64:
                width = (random.randint(0, 5) + 12) * 32
                self.shape = (width, width)
            elif self.seen < 16000 * 64:
                width = (random.randint(0, 7) + 11) * 32
                self.shape = (width, width)
            else:  # self.seen < 20000*64:
                width = (random.randint(0, 9) + 10) * 32
                self.shape = (width, width)

        if self.train:
            jitter = 0.1
            hue = 0.1
            saturation = 1.5
            exposure = 1.5
            # Get the data-augmented image and a vector label with ground truth
            img, label, phoc_matrix = load_data_detection(
                imgpath, self.shape, jitter, hue, saturation, exposure)
            label = torch.from_numpy(label)
            phoc_matrix = torch.from_numpy(phoc_matrix)

        else:
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img = img.resize(self.shape)

            labpath = imgpath.replace('images', 'labels').replace(
                'JPEGImages',
                'labels').replace('.jpg', '.txt').replace('.png', '.txt')
            #label = torch.zeros(50*5)
            label = torch.zeros(100 * 5)
            '''try:
                tmp = torch.from_numpy(read_truths_args(labpath, 8.0/img.width).astype('float32'))

            except Exception:
                print("Exception on Validation has occurred!")
                tmp = torch.zeros(1,5)
            '''

            tmp, word_list = read_truths_args(labpath, 8.0 / img.width)
            tmp = torch.from_numpy(tmp).type(torch.FloatTensor)

            tmp = tmp.view(-1)
            tsz = tmp.numel()
            #print('labpath = %s , tsz = %d' % (labpath, tsz))
            if tsz > 100 * 5:
                label = tmp[0:100 * 5]
            elif tsz > 0:
                label[0:tsz] = tmp

            phoc_matrix = fill_phoc(word_list)
            phoc_matrix = torch.from_numpy(phoc_matrix)

        # Transform the image to a PyTorch tensor
        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers
        return (img, label, phoc_matrix)
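Example #22 attaches a PHOC (pyramidal histogram of characters) matrix to each sample for word spotting; fill_phoc itself is not shown here. A minimal sketch of a two-level PHOC over the lowercase alphabet, using a simple any-overlap assignment — the standard PHOC definition assigns a character to a region only when the overlap covers at least half the character's span, and typically adds more levels, digits, and common bigrams:

    import numpy as np

    ALPHABET = 'abcdefghijklmnopqrstuvwxyz'

    def phoc_vector(word, levels=(1, 2)):
        """Binary histogram of characters per region at each pyramid level."""
        word = word.lower()
        n = len(word)
        if n == 0:
            return np.zeros(sum(levels) * len(ALPHABET))
        feats = []
        for level in levels:
            hist = np.zeros((level, len(ALPHABET)))
            for i, ch in enumerate(word):
                if ch not in ALPHABET:
                    continue
                # The character spans [i/n, (i+1)/n); mark every region of
                # this pyramid level that the span overlaps
                for r in range(level):
                    if i / n < (r + 1) / level and (i + 1) / n > r / level:
                        hist[r, ALPHABET.index(ch)] = 1
            feats.append(hist.ravel())
        return np.concatenate(feats)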
Example #23
    def __getitem__(self, index):

        # Ensure the index does not exceed the number of samples in the dataset; otherwise raise an error
        assert index <= len(self), 'index range error'

        # Get the image path
        imgpath = self.lines[index].rstrip()

        # Decide which size to resize the image to, depending on the training iteration
        if self.train and index % self.batch_size == 0:
            if self.seen < 400 * self.batch_size:
                width = 13 * self.cell_size
                self.shape = (width, width)
            elif self.seen < 800 * self.batch_size:
                width = (random.randint(0, 7) + 13) * self.cell_size
                self.shape = (width, width)
            elif self.seen < 1200 * self.batch_size:
                width = (random.randint(0, 9) + 12) * self.cell_size
                self.shape = (width, width)
            elif self.seen < 1600 * self.batch_size:
                width = (random.randint(0, 11) + 11) * self.cell_size
                self.shape = (width, width)
            elif self.seen < 2000 * self.batch_size:
                width = (random.randint(0, 13) + 10) * self.cell_size
                self.shape = (width, width)
            elif self.seen < 2400 * self.batch_size:
                width = (random.randint(0, 15) + 9) * self.cell_size
                self.shape = (width, width)
            elif self.seen < 3000 * self.batch_size:
                width = (random.randint(0, 17) + 8) * self.cell_size
                self.shape = (width, width)
            else:
                width = (random.randint(0, 19) + 7) * self.cell_size
                self.shape = (width, width)

        if self.train:
            # During training, decide how much data augmentation to apply
            jitter = 0.2
            hue = 0.1
            saturation = 1.5
            exposure = 1.5

            # Get background image path
            #random_bg_index = random.randint(0, len(self.bg_file_names) - 1)
            #bgpath = self.bg_file_names[random_bg_index]

            # Get the data-augmented image and its corresponding labels
            img, label, label_1, label_2 = load_data_detection(
                imgpath, self.shape, jitter, hue, saturation, exposure)

            # Convert the labels to PyTorch tensors
            label = torch.from_numpy(label)
            label_1 = torch.from_numpy(label_1).float()
            label_2 = torch.from_numpy(label_2).float()

        else:
            # Get the validation image, resize it to the network input size
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img = img.resize(self.shape)

            # Read the validation labels, allowing up to 50 ground-truth objects per image
            labpath = imgpath.replace('images', 'labels').replace(
                'JPEGImages',
                'labels').replace('.jpeg', '.txt').replace('.png', '.txt')
            lab_pose_path = imgpath.replace('images', 'labels_pose').replace(
                'JPEGImages',
                'labels_pose').replace('.jpeg',
                                       '.txt').replace('.jpg', '.txt')

            label = torch.zeros(50 * 21)
            label_1 = torch.zeros(1 * 3)
            label_2 = torch.zeros(1 * 4)
            if os.path.getsize(labpath):
                ow, oh = img.size
                tmp = torch.from_numpy(read_truths_args(labpath))
                tmp = tmp.view(-1)
                tsz = tmp.numel()
                if tsz > 50 * 21:
                    label = tmp[0:50 * 21]
                elif tsz > 0:
                    label[0:tsz] = tmp
            if os.path.getsize(lab_pose_path):
                tmp_1, tmp_2 = read_truths_pose(lab_pose_path)
                tmp_1 = torch.from_numpy(tmp_1)
                tmp_2 = torch.from_numpy(tmp_2)
                tmp_1 = tmp_1.view(-1)
                tsz_1 = tmp_1.numel()
                tmp_2 = tmp_2.view(-1)
                tsz_2 = tmp_2.numel()
                if tsz_1 > 1 * 7:
                    label_1 = tmp_1[0:1 * 7]
                elif tsz_1 > 0:
                    label_1[0:tsz_1] = tmp_1
                if tsz_2 > 1 * 7:
                    label_2 = tmp_2[0:1 * 7]
                elif tsz_2 > 0:
                    label_2[0:tsz_2] = tmp_2

        # Transform the image data to PyTorch tensors
        if self.transform is not None:
            img = self.transform(img)

        # If there is any PyTorch-specific transformation, transform the label data
        if self.target_transform is not None:
            label = self.target_transform(label)

        # Increase the number of seen examples
        self.seen = self.seen + self.num_workers

        # Return the retrieved image and its corresponding label
        return (img, label, label_1, label_2)
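Every example in this collection repeats the same pad-or-truncate idiom (tmp.view(-1), tsz = tmp.numel(), then slice or copy into a zero buffer). A small helper capturing that idiom — purely a refactoring sketch, not part of the original code:

    import torch

    def pad_or_truncate(flat, size):
        """Fit a flat tensor into a fixed-size, zero-padded buffer."""
        out = torch.zeros(size)
        n = min(flat.numel(), size)
        out[0:n] = flat.view(-1)[0:n]
        return out

    # e.g. label = pad_or_truncate(tmp, 50 * 21)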
Example #24
    def __getitem__(self, index):
        assert index <= len(self), 'index range error'
        imgpath = self.lines[index].rstrip()

        if self.train and index % 64 == 0:
            if self.seen < 4000 * 64:
                width = 13 * 32
                self.shape = (width, width)
            elif self.seen < 8000 * 64:
                width = (random.randint(0, 3) + 13) * 32
                self.shape = (width, width)
            elif self.seen < 12000 * 64:
                width = (random.randint(0, 5) + 12) * 32
                self.shape = (width, width)
            elif self.seen < 16000 * 64:
                width = (random.randint(0, 7) + 11) * 32
                self.shape = (width, width)
            else:  # self.seen < 20000*64:
                width = (random.randint(0, 9) + 10) * 32
                self.shape = (width, width)

        if self.train:
            # jitter = 0.2
            jitter = 0.0
            hue = 0.014  # was 0.05
            saturation = 1.1  # was 1.5
            exposure = 1.2  # was 1.5

            # Get background image path
            random_bg_index = random.randint(0, len(self.bg_file_names) - 1)
            bgpath = self.bg_file_names[random_bg_index]

            img, label = load_data_detection(imgpath, self.shape, jitter, hue,
                                             saturation, exposure, bgpath)

            if debug_multi:
                np_img = np.array(img)
                np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
                print(imgpath)
                print(label)
                print('length of label: ' + str(len(label)))
                # Replace the directory after /FRC2019/ with the foreground
                # object class name before writing the debug outputs
                mod_imgpath = imgpath.replace(
                    '../FRC2019/brownGlyph/JPEGImages/',
                    './test_load_data_detection/').replace('jpg', 'png')
                mod_labpath = imgpath.replace(
                    '../FRC2019/brownGlyph/JPEGImages/',
                    './test_load_data_detection/').replace('jpg', 'txt')
                print(mod_imgpath)
                print(mod_labpath)
                cv2.imwrite(mod_imgpath, np_img)
                np.savetxt(mod_labpath, label)

            label = torch.from_numpy(label)
        else:
            img = Image.open(imgpath).convert('RGB')
            if self.shape:
                img = img.resize(self.shape)
            # Replace the directory after /FRC2019/ with the foreground object class name
            labpath = imgpath.replace('brownGlyph', self.objclass).replace(
                'images',
                'labels_occlusion').replace('JPEGImages',
                                            'labels_occlusion').replace(
                                                '.jpg', '.txt').replace(
                                                    '.png', '.txt')

            label = torch.zeros(50 * 21)
            if os.path.getsize(labpath):
                ow, oh = img.size
                #tmp = torch.from_numpy(read_truths_args(labpath, 8.0/ow))
                tmp = torch.from_numpy(read_truths_args(labpath))
                tmp = tmp.view(-1)
                tsz = tmp.numel()
                if tsz > 50 * 21:
                    label = tmp[0:50 * 21]
                elif tsz > 0:
                    label[0:tsz] = tmp

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        self.seen = self.seen + self.num_workers
        return (img, label)
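A closing note on the bookkeeping shared by all of these datasets: self.seen advances by self.num_workers on every __getitem__ call, so each DataLoader worker's private counter approximates the global number of samples seen, which is what the multi-scale schedule keys off. A minimal usage sketch — the ListDataset name and its constructor arguments are illustrative, not from the original code:

    from torch.utils.data import DataLoader

    num_workers = 4
    dataset = ListDataset('train.txt', shape=(416, 416), train=True,
                          num_workers=num_workers)
    loader = DataLoader(dataset, batch_size=64, num_workers=num_workers)

    for imgs, labels in loader:
        pass  # forward/backward pass goes here

Because the resize trigger (index % 64 == 0) assumes indices arrive roughly in order, loaders built on these datasets typically leave index shuffling off and shuffle the image list itself instead.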