def process_images(self, raw, clean):
        """Crop a raw/clean pair identically and build a masked sample.

        Args:
            raw: Image used to sample the crop box and to derive the mask.
            clean: Image cropped with the same box; it becomes the target.

        Returns:
            A ``(corrupted_img, binary_mask, clean_img)`` tuple where
            ``binary_mask`` is ``1 - mask`` expanded to 3 channels and
            ``corrupted_img`` is the transformed clean image multiplied by it.
        """
        # A single crop box is sampled from `raw` and reused for both images,
        # so the same region is taken from each of them.
        top, left, height, width = RandomResizedCrop.get_params(
            raw, scale=(0.5, 2.0), ratio=(3. / 4., 4. / 3.))
        crop_kwargs = dict(size=self.img_size, interpolation=Image.BICUBIC)
        raw_img = resized_crop(raw, top, left, height, width, **crop_kwargs)
        clean_img = resized_crop(clean, top, left, height, width,
                                 **crop_kwargs)

        # The mask is derived from the cropped pair before self.transformer
        # touches the clean image.
        hole_mask = to_tensor(self.get_mask(raw_img, clean_img))

        # Inverted mask, broadcast to 3 channels.
        # NOTE(review): assumes self.transformer returns a tensor that can be
        # multiplied with a 3-channel mask - confirm.
        binary_mask = (1 - hole_mask).expand(3, -1, -1)
        clean_img = self.transformer(clean_img)
        return clean_img * binary_mask, binary_mask, clean_img
Ejemplo n.º 2
0
    def process_images(self, raw, clean):
        """Crop a raw/clean pair identically and return input, mask and target.

        Args:
            raw: Image used to sample the crop box and to derive the mask.
            clean: Image cropped with the same box; it becomes the target.

        Returns:
            A 3-tuple: the transformed difference image between the mask and
            the clean crop, the inverted dilated mask (``1 - mask``), and the
            transformed clean crop.
        """
        # One crop box, sampled from `raw`, is shared by both images.
        top, left, height, width = RandomResizedCrop.get_params(
            raw, scale=(0.5, 2.0), ratio=(3. / 4., 4. / 3.))
        crop_kwargs = dict(size=self.img_size, interpolation=Image.BICUBIC)
        raw_img = resized_crop(raw, top, left, height, width, **crop_kwargs)
        clean_img = resized_crop(clean, top, left, height, width,
                                 **crop_kwargs)

        # The mask is computed on the cropped pair, before any further
        # augmentation is applied.
        mask = self.get_mask(raw_img, clean_img)

        # Binarise: every strictly positive mask value counts as masked.
        hole_mask = (to_tensor(mask) > 0).float()
        # A 5x5 max-pool with stride 1 and padding 2 keeps the spatial size
        # and spreads each masked pixel over its 5x5 neighbourhood.
        hole_mask = torch.nn.functional.max_pool2d(hole_mask,
                                                   kernel_size=5,
                                                   stride=1,
                                                   padding=2)

        # The network input is the difference between the mask image and the
        # clean crop, not the raw crop itself.
        difference_img = ImageChops.difference(mask, clean_img)

        # Keep the two transformer calls in this order (input first, target
        # second) in case the transformer is stochastic.
        model_input = self.transformer(difference_img)
        valid_mask = 1 - hole_mask
        target = self.transformer(clean_img)
        return model_input, valid_mask, target
Ejemplo n.º 3
0
    def process_images(self, raw, clean):
        """Crop `raw` and `clean` with one shared box and return both as tensors.

        Args:
            raw: Image used to sample the crop box; it is also passed through
                ``self.transformer`` after cropping.
            clean: Image cropped with the same box and returned untransformed.

        Returns:
            ``(to_tensor(transformed raw crop), to_tensor(clean crop))``.
        """
        top, left, height, width = RandomResizedCrop.get_params(
            raw, scale=(0.1, 2), ratio=(3. / 4., 4. / 3.))

        cropped_raw = resized_crop(raw, top, left, height, width,
                                   size=self.img_size,
                                   interpolation=Image.BICUBIC)
        # NOTE(review): to_tensor is applied to this result below, so
        # self.transformer presumably returns a PIL image or ndarray rather
        # than a tensor - confirm.
        augmented_raw = self.transformer(cropped_raw)

        mask_img = resized_crop(clean, top, left, height, width,
                                size=self.img_size,
                                interpolation=Image.BICUBIC)
        return to_tensor(augmented_raw), to_tensor(mask_img)
Ejemplo n.º 4
0
    def process_images(self, raw, clean):
        """Crop a raw/clean pair identically and return the input with its mask.

        Args:
            raw: Image used to sample the crop box; its crop is transformed
                and returned.
            clean: Image cropped with the same box, used only for the mask.

        Returns:
            ``(transformed raw crop, mask)`` where the mask is whatever
            ``self.get_mask`` produces for the cropped pair.
        """
        top, left, height, width = RandomResizedCrop.get_params(
            raw, scale=(0.5, 2.0), ratio=(3. / 4., 4. / 3.))
        crop_kwargs = dict(size=self.img_size, interpolation=Image.BICUBIC)
        cropped_raw = resized_crop(raw, top, left, height, width,
                                   **crop_kwargs)
        cropped_clean = resized_crop(clean, top, left, height, width,
                                     **crop_kwargs)

        # The mask must be computed on the un-augmented crops, so it is taken
        # before self.transformer is applied to the raw crop.
        mask_tensor = self.get_mask(cropped_raw, cropped_clean)

        return self.transformer(cropped_raw), mask_tensor
Ejemplo n.º 5
0
class InitialObservationNumpyJitteringDataset(data.Dataset):
    """Dataset yielding jittered (observation, environment-frame) pairs.

    ``data['observations']`` must be a uint8 array whose first two axes are
    trajectory and step; each frame is reshaped to 48x48x3 when read. If
    ``data`` has no ``'env'`` entry, the first step of every trajectory is
    stored under that key (the passed-in dict is modified in place).
    """

    def __init__(self, data, info=None):
        observations = data['observations']
        assert observations.dtype == np.uint8

        self.size = observations.shape[0]
        self.traj_length = observations.shape[1]
        self.data = data
        self.info = info

        self.jitter = ColorJitter((0.5, 1.5), (0.9, 1.1), (0.9, 1.1),
                                  (-0.1, 0.1))
        self.crop = RandomResizedCrop((48, 48), (0.9, 0.9), (1, 1))

        # Fall back to each trajectory's first frame as its environment image.
        if 'env' not in self.data:
            self.data['env'] = self.data['observations'][:, 0, :]

    def __len__(self):
        """Return the total number of (trajectory, step) pairs."""
        return self.size * self.traj_length

    def __getitem__(self, idx):
        """Return ``{'x_t': ..., 'env': ...}`` for the flat index `idx`.

        Both images get the same crop box and the same jitter object, and
        are then flattened and passed through ``normalize_image``.
        """
        # Flat index -> (trajectory, step within trajectory).
        traj_i = idx // self.traj_length
        trans_i = idx % self.traj_length

        frame_shape = (48, 48, 3)
        out_size = (48, 48)

        obs_frame = self.data['observations'][traj_i, trans_i]
        env_frame = self.data['env'][traj_i]
        x = Image.fromarray(obs_frame.reshape(frame_shape), mode='RGB')
        c = Image.fromarray(env_frame.reshape(frame_shape), mode='RGB')

        # Upsampling gives bad images, so the crop scale and ratio are
        # pinned to 1 for now.
        top, left, height, width = self.crop.get_params(c, (1, 1), (1, 1))

        # One jitter is drawn and applied to both images.
        jitter = self.jitter.get_params(0.5, 0.1, 0.1, 0.1)

        x = jitter(
            F.resized_crop(x, top, left, height, width, out_size,
                           Image.BICUBIC))
        c = jitter(
            F.resized_crop(c, top, left, height, width, out_size,
                           Image.BICUBIC))

        return {
            'x_t': normalize_image(np.array(x).flatten()).squeeze(),
            'env': normalize_image(np.array(c).flatten()).squeeze(),
        }
Ejemplo n.º 6
0
    def process_images(self, raw, clean):
        """Crop a raw/clean pair identically and build a masked training sample.

        Args:
            raw: Image used to sample the crop box and to derive the mask;
                random masks are drawn onto its crop when
                ``self.add_random_masks`` is set.
            clean: Image cropped with the same box; it becomes the target.

        Returns:
            A ``(corrupted_img, binary_mask, clean_img)`` tuple. In
            ``binary_mask`` valid positions are 1 and holes are 0;
            ``corrupted_img`` is the transformed clean image multiplied by it.
        """
        # One crop box, sampled from `raw`, is shared by both images.
        i, j, h, w = RandomResizedCrop.get_params(raw, scale=(0.5, 2.0), ratio=(3. / 4., 4. / 3.))
        raw_img = resized_crop(raw, i, j, h, w, size=self.img_size, interpolation=Image.BICUBIC)
        if self.add_random_masks:
            raw_img = self.random_mask.draw(raw_img)

        clean_img = resized_crop(clean, i, j, h, w, self.img_size, interpolation=Image.BICUBIC)

        # get mask before further image augment
        mask = self.get_mask(raw_img, clean_img)
        mask_t = to_tensor(mask)
        mask_t = (mask_t > brightness_difference).float()

        # Collapse the channel axis: a pixel is masked if any channel is.
        # torch.max with `dim` returns a (values, indices) pair, so only the
        # values are kept; passing the whole pair to max_pool2d would fail.
        mask_t, _ = torch.max(mask_t, dim=0, keepdim=True)
        # A 9x9 max-pool with stride 1 and padding 4 keeps the spatial size
        # and spreads each masked pixel over its 9x9 neighbourhood.
        mask_t = torch.nn.functional.max_pool2d(mask_t, kernel_size=9, stride=1, padding=4)

        # corrupt the clean images rather than using the raw ones
        binary_mask = (1 - mask_t)  # valid positions are 1; holes are 0
        binary_mask = binary_mask.expand(3, -1, -1)
        clean_img = self.transformer(clean_img)
        corrupted_img = clean_img * binary_mask
        return corrupted_img, binary_mask, clean_img
    def process_images(self, clean, mask):
        """Crop an image/mask pair identically and build a masked sample.

        Args:
            clean: Image used to sample the crop box; it becomes the target.
            mask: Mask image cropped with the same box, then thresholded
                and dilated.

        Returns:
            A ``(corrupted_img, binary_mask, clean_img)`` tuple. In
            ``binary_mask`` valid positions are 1 and holes are 0;
            ``corrupted_img`` is the transformed clean image multiplied by it.
        """
        # One crop box, sampled from `clean`, is shared by image and mask.
        top, left, height, width = RandomResizedCrop.get_params(
            clean, scale=(0.5, 2.0), ratio=(3. / 4., 4. / 3.))
        crop_kwargs = dict(size=self.img_size, interpolation=Image.BICUBIC)
        clean_img = resized_crop(clean, top, left, height, width,
                                 **crop_kwargs)
        mask = resized_crop(mask, top, left, height, width, **crop_kwargs)

        if self.add_random_masks:
            # Extra random masks are drawn on a copy of the cropped mask.
            mask = random_masks(mask.copy(), size=self.img_size[0], offset=10)

        # Threshold on the 0-255 scale to get a hard 0/255 uint8 mask.
        cutoff = brightness_difference * 255
        mask_arr = np.where(np.array(mask) > cutoff, np.uint8(255),
                            np.uint8(0))
        # Grow the masked area with a single 10x10 dilation pass.
        mask_arr = cv2.dilate(mask_arr,
                              np.ones((10, 10), np.uint8),
                              iterations=1)

        # A trailing channel axis is added before conversion.
        # NOTE(review): presumably the mask is single-channel (H x W) here -
        # confirm.
        hole_mask = to_tensor(np.expand_dims(mask_arr, -1))

        # valid positions are 1; holes are 0
        binary_mask = (1 - hole_mask).expand(3, -1, -1)
        clean_img = self.transformer(clean_img)
        return clean_img * binary_mask, binary_mask, clean_img
Ejemplo n.º 8
0
def random_resize_crop(augment_targets,
                       scale,
                       ratio,
                       size,
                       threshold,
                       pre_crop_area=None):
    """Apply one shared resized crop to an image and its three target maps.

    Args:
        augment_targets: Sequence of exactly four arrays, in the order
            ``image, region_score, affinity_score, confidence_mask``.
        scale: Scale range passed to ``RandomResizedCrop.get_params`` when a
            random crop is drawn.
        ratio: Aspect-ratio range passed to ``RandomResizedCrop.get_params``
            when a random crop is drawn.
        size: Side length of the square output; each result is resized to
            ``(size, size)``.
        threshold: A random crop with `scale`/`ratio` is used when
            ``random.random() < threshold``; otherwise the crop parameters
            are drawn with scale and ratio both fixed to ``(1.0, 1.0)``.
        pre_crop_area: Optional ``(i, j, h, w)`` crop box. When it is not
            ``None`` it is used as-is and no random parameters are drawn.

    Returns:
        ``[image, region_score, affinity_score, confidence_mask]`` as numpy
        arrays after cropping and resizing.
    """
    image, region_score, affinity_score, confidence_mask = augment_targets
    pil_targets = [
        Image.fromarray(target)
        for target in (image, region_score, affinity_score, confidence_mask)
    ]

    # Identity check rather than `!= None`: comparing a numpy array with
    # `!=` is element-wise and its truth value is ambiguous in an `if`.
    if pre_crop_area is not None:
        i, j, h, w = pre_crop_area
    else:
        if random.random() < threshold:
            crop_scale, crop_ratio = scale, ratio
        else:
            crop_scale, crop_ratio = (1.0, 1.0), (1.0, 1.0)
        i, j, h, w = RandomResizedCrop.get_params(pil_targets[0],
                                                  scale=crop_scale,
                                                  ratio=crop_ratio)

    # The image and the two score maps are resampled bicubically; the
    # confidence mask uses nearest-neighbour interpolation.
    interpolations = (
        InterpolationMode.BICUBIC,
        InterpolationMode.BICUBIC,
        InterpolationMode.BICUBIC,
        InterpolationMode.NEAREST,
    )
    return [
        np.array(
            resized_crop(target,
                         i,
                         j,
                         h,
                         w,
                         size=(size, size),
                         interpolation=interpolation))
        for target, interpolation in zip(pil_targets, interpolations)
    ]