예제 #1
0
파일: MyDataset.py 프로젝트: wymGAKKI/saps
    def __getitem__(self, index):
        """Assemble one sample: normals, images, reflectance, mask, lights, intensities.

        The normal map is the ``'normal'`` entry of a ``.mat`` file. The mask is
        the ``'mask'`` entry of another ``.mat`` file. The reflectance and every
        image in ``img_list`` are read with ``imread`` and scaled by 1/255. They
        are cut to their first three channels when they have four. Images are
        concatenated along the channel axis.

        Args:
            index: sample index forwarded to ``self._getInputPath``.

        Returns:
            dict containing:
              * ``'normal'``, ``'img'`` and ``'reflectance'``, converted with
                ``pms_transforms.arrayToTensor``;
              * ``'lights'`` and ``'ints'``, float tensors viewed as ``(-1, 1, 1)``;
              * ``'mask'``, with a leading singleton dimension added.

        NOTE(review): only ``img`` and ``normal`` go through ``rescale`` and
        ``randomCrop``. ``reflectance`` and ``mask`` keep their loaded size, so
        they may not line up with ``img`` when either transform changes the
        spatial size. Confirm this is intended.

        NOTE(review): only ``crop_h == h`` skips the rescale step.
        ``np.random.randint(crop_w, w)`` still raises when ``crop_w >= w``.

        The author's recorded example shapes were: normal (3, 128, 128),
        img (6, 128, 128), mask (1, 128, 128), lights/ints (6, 1, 1).
        """
        normal_path, img_list, dirs, reflectance_path, mask_path = self._getInputPath(index)
        normal = sio.loadmat(normal_path)['normal'].astype(np.float32)

        def _read_rgb(path):
            # Scale by 1/255 and keep only the first three channels when a
            # 4th channel is present (presumably alpha - TODO confirm).
            pixels = imread(path).astype(np.float32) / 255.0
            return pixels[:, :, :3] if pixels.shape[2] == 4 else pixels

        reflectance = _read_rgb(reflectance_path)
        imgs = [_read_rgb(path) for path in img_list]
        img = np.concatenate(imgs, 2)

        h, w, c = img.shape
        args = self.args
        crop_h, crop_w = args.crop_h, args.crop_w
        if args.rescale and crop_h != h:
            if args.rand_sc:
                # Height is drawn before width, matching the original order.
                sc_h = np.random.randint(crop_h, h)
                sc_w = np.random.randint(crop_w, w)
            else:
                sc_h, sc_w = args.scale_h, args.scale_w
            img, normal = pms_transforms.rescale(img, normal, [sc_h, sc_w])

        if args.crop:
            img, normal = pms_transforms.randomCrop(img, normal, [crop_h, crop_w])

        if args.int_aug:
            ints = pms_transforms.getIntensity(len(imgs))
            # Multiply each channel by its own intensity via a diagonal matrix.
            img = np.dot(img, np.diag(ints.reshape(-1)))
        else:
            ints = np.ones(c)

        mask = sio.loadmat(mask_path)['mask'].astype(np.float32)
        # Rescale normals to unit length; the epsilon avoids division by zero.
        normal = normal / (np.sqrt(np.square(normal).sum(2, keepdims=True)) + 1e-10)

        item = {name: pms_transforms.arrayToTensor(arr) for name, arr in
                (('normal', normal), ('img', img), ('reflectance', reflectance))}
        item['lights'] = torch.from_numpy(dirs).view(-1, 1, 1).float()
        item['ints'] = torch.from_numpy(ints).view(-1, 1, 1).float()
        item['mask'] = torch.from_numpy(mask).unsqueeze(0)
        return item
예제 #2
0
    def __getitem__(self, index):
        """Load one photometric-stereo sample, optionally with shadow maps.

        The normal map is read with ``imread`` and mapped from [0, 255] to
        [-1, 1]. Every image in ``img_list`` is read and scaled by 1/255.
        When ``args.shadow`` is set, a shadow map is read for each image, from
        the same path with ``'/l_'`` replaced by ``'/s_'``. Its first channel is
        appended after all image channels, so rescale and crop treat images and
        shadows identically.

        Args:
            index: sample index forwarded to ``self._getInputPath``.

        Returns:
            dict containing:
              * ``'N'``, ``'img'`` and ``'mask'``, converted with
                ``pms_transforms.arrayToTensor``;
              * ``'light'`` (float tensor viewed as ``(-1, 1, 1)``), only when
                ``args.in_light`` is set;
              * ``'shadow'``, only when ``args.shadow`` is set. It is a NumPy
                copy of the shadow channels taken before colour/noise
                augmentation.

        NOTE(review): ``'img'`` still contains the appended shadow channels.
        ``'shadow'`` is returned as a NumPy array in the (h, w, channels)
        layout, not via ``arrayToTensor``. Both match the original behavior;
        confirm they are intended.
        """
        normal_path, img_list, lights = self._getInputPath(index)
        normal = imread(normal_path).astype(np.float32) / 255.0 * 2 - 1
        imgs = []
        shadows = []
        for i in img_list:
            img = imread(i).astype(np.float32) / 255.0
            if self.args.shadow:
                s = imread(i.replace('/l_', '/s_')).astype(np.float32) / 255.0
                shadows.append(s[:, :, 0:1])
            imgs.append(img)

        img = np.concatenate(imgs, 2)
        # Number of image channels before shadows are appended. The original
        # code hard-coded 96 here, which is only right for one image-set size.
        n_img_ch = img.shape[2]
        if self.args.shadow:
            # Concatenate only when shadows were loaded: np.concatenate([])
            # raises ValueError, so the unconditional version crashed every
            # sample whenever args.shadow was False.
            img = np.concatenate([img] + shadows, 2)

        h, w = img.shape[:2]
        crop_h, crop_w = self.args.crop_h, self.args.crop_w
        # np.random.randint(low, high) requires low < high. Skip the random
        # rescale when there is no room above the crop size instead of raising.
        if self.args.rescale and h > crop_h and w > crop_w:
            sc_h = np.random.randint(crop_h, h)
            sc_w = np.random.randint(crop_w, w)
            img, normal = pms_transforms.rescale(img, normal, [sc_h, sc_w])

        if self.args.crop:
            img, normal = pms_transforms.randomCrop(img, normal,
                                                    [crop_h, crop_w])
        if self.args.shadow:
            # Snapshot the shadow channels before colour/noise augmentation.
            shadow = np.copy(img[:, :, n_img_ch:])
        if self.args.color_aug:
            img = (img * np.random.uniform(1, 3)).clip(0, 2)

        if self.args.noise_aug:
            img = pms_transforms.randomNoiseAug(img, self.args.noise)

        mask = pms_transforms.normalToMask(normal)
        # Zero out normals outside the mask (mask repeated to 3 channels).
        normal = normal * mask.repeat(3, 2)
        item = {'N': normal, 'img': img, 'mask': mask}
        for k in item.keys():
            item[k] = pms_transforms.arrayToTensor(item[k])

        if self.args.in_light:
            item['light'] = torch.from_numpy(lights).view(-1, 1, 1).float()
        if self.args.shadow:
            item['shadow'] = shadow

        return item
예제 #3
0
    def __getitem__(self, index):
        """Build one sample: images, masked unit normals, mask, light dirs, intensities.

        The normal map is read with ``imread`` and mapped from [0, 255] to
        [-1, 1]. Every image in ``img_list`` is read, scaled by 1/255 and
        concatenated along the channel axis. Optional rescale, crop, colour,
        per-channel intensity and noise augmentations are controlled by
        ``self.args``.

        Args:
            index: sample index forwarded to ``self._getInputPath``.

        Returns:
            dict containing:
              * ``'normal'``, ``'img'`` and ``'mask'``, converted with
                ``pms_transforms.arrayToTensor``;
              * ``'dirs'`` and ``'ints'``, float tensors viewed as
                ``(-1, 1, 1)``.
        """
        normal_path, img_list, dirs = self._getInputPath(index)
        normal = imread(normal_path).astype(np.float32) / 255.0 * 2 - 1
        imgs = [imread(path).astype(np.float32) / 255.0 for path in img_list]
        img = np.concatenate(imgs, 2)

        h, w, c = img.shape
        args = self.args
        crop_h, crop_w = args.crop_h, args.crop_w

        # Rescale is skipped when the image height already equals the crop height.
        if args.rescale and crop_h != h:
            if args.rand_sc:
                sc_h = np.random.randint(crop_h, h)
                sc_w = np.random.randint(crop_w, w)
            else:
                sc_h, sc_w = args.scale_h, args.scale_w
            img, normal = pms_transforms.rescale(img, normal, [sc_h, sc_w])

        if args.crop:
            img, normal = pms_transforms.randomCrop(img, normal, [crop_h, crop_w])

        if args.color_aug:
            img = img * np.random.uniform(1, args.color_ratio)

        if not args.int_aug:
            ints = np.ones(c)
        else:
            ints = pms_transforms.getIntensity(len(imgs))
            # Scale every channel by its own intensity (diagonal matrix product).
            img = np.dot(img, np.diag(ints.reshape(-1)))

        if args.noise_aug:
            img = pms_transforms.randomNoiseAug(img, args.noise)

        mask = pms_transforms.normalToMask(normal)
        normal = normal * mask.repeat(3, 2)
        # Rescale normals to unit length; the epsilon guards masked-out zeros.
        length = np.sqrt(np.square(normal).sum(2, keepdims=True))
        normal = normal / (length + 1e-10)

        item = {key: pms_transforms.arrayToTensor(arr) for key, arr in
                (('normal', normal), ('img', img), ('mask', mask))}
        item['dirs'] = torch.from_numpy(dirs).view(-1, 1, 1).float()
        item['ints'] = torch.from_numpy(ints).view(-1, 1, 1).float()
        return item
예제 #4
0
    def __getitem__(self, index):
        """Load one sample: images, masked unit normals, mask and optional lights.

        The normal map is read with ``imread`` and mapped from [0, 255] to
        [-1, 1]. Every image in ``img_list`` is read, scaled by 1/255 and
        concatenated along the channel axis.

        Optional steps, controlled by ``self.args``:
          * random rescale and crop;
          * colour augmentation, applied only when ``normalize`` is off;
          * per-image normalization via ``pms_transforms.normalize``, after
            splitting into 3-channel chunks;
          * noise augmentation.

        Args:
            index: sample index forwarded to ``self._getInputPath``.

        Returns:
            dict containing:
              * ``'N'``, ``'img'`` and ``'mask'``, converted with
                ``pms_transforms.arrayToTensor``;
              * ``'light'`` (float tensor viewed as ``(-1, 1, 1)``), only when
                ``args.in_light`` is set.
        """
        normal_path, img_list, lights = self._getInputPath(index)
        normal = imread(normal_path).astype(np.float32) / 255.0 * 2 - 1
        imgs = []
        for i in img_list:
            img = imread(i).astype(np.float32) / 255.0
            imgs.append(img)
        img = np.concatenate(imgs, 2)

        h, w, c = img.shape
        crop_h, crop_w = self.args.crop_h, self.args.crop_w
        # np.random.randint(low, high) requires low < high. The unguarded call
        # raised ValueError for images already at (or below) the crop size.
        # Images strictly larger than the crop rescale exactly as before.
        if self.args.rescale and h > crop_h and w > crop_w:
            sc_h = np.random.randint(crop_h, h)
            sc_w = np.random.randint(crop_w, w)
            img, normal = pms_transforms.rescale(img, normal, [sc_h, sc_w])

        if self.args.crop:
            img, normal = pms_transforms.randomCrop(img, normal,
                                                    [crop_h, crop_w])

        if self.args.color_aug and not self.args.normalize:
            img = (img * np.random.uniform(1, 3)).clip(0, 2)

        if self.args.normalize:
            # Split back into 3-channel images, normalize, and re-concatenate.
            imgs = np.split(img, img.shape[2] // 3, 2)
            imgs = pms_transforms.normalize(imgs)
            img = np.concatenate(imgs, 2)

        if self.args.noise_aug:
            img = pms_transforms.randomNoiseAug(img, self.args.noise)

        mask = pms_transforms.normalToMask(normal)
        normal = normal * mask.repeat(3, 2)
        # Rescale normals to unit length; the epsilon guards masked-out zeros.
        norm = np.sqrt((normal * normal).sum(2, keepdims=True))
        normal = normal / (norm + 1e-10)

        item = {'N': normal, 'img': img, 'mask': mask}
        for k in item.keys():
            item[k] = pms_transforms.arrayToTensor(item[k])

        if self.args.in_light:
            item['light'] = torch.from_numpy(lights).view(-1, 1, 1).float()

        return item