def process_images(self, raw, clean):
    """Build a (corrupted, mask, target) triple from a raw/clean image pair.

    The same randomly sampled crop window is applied to both inputs so
    they stay pixel-aligned. The mask is computed from the cropped pair
    before ``self.transformer`` is applied to the clean crop.

    Args:
        raw: Image passed to ``self.get_mask`` as the first argument
            (after cropping).
        clean: Image that becomes the target and the source of the
            corrupted output.

    Returns:
        A tuple ``(corrupted_img, binary_mask, clean_img)`` where
        ``binary_mask`` is ``1 - to_tensor(mask)`` expanded to 3 channels
        and ``corrupted_img`` is ``clean_img * binary_mask``.
    """
    top, left, height, width = RandomResizedCrop.get_params(
        raw, scale=(0.5, 2.0), ratio=(3. / 4., 4. / 3.))
    raw_crop = resized_crop(
        raw, top, left, height, width,
        size=self.img_size, interpolation=Image.BICUBIC)
    clean_crop = resized_crop(
        clean, top, left, height, width,
        self.img_size, interpolation=Image.BICUBIC)

    # The mask comes from the un-augmented crops; augmentation happens below.
    keep = 1 - to_tensor(self.get_mask(raw_crop, clean_crop))
    # Assumes the mask tensor has a single channel so it can be expanded
    # to 3 -- TODO confirm against get_mask.
    keep = keep.expand(3, -1, -1)

    target = self.transformer(clean_crop)
    return target * keep, keep, target
def process_images(self, raw, clean):
    """Produce a difference image, an inverted dilated mask and the target.

    Both inputs are cropped with the same random window. The mask from
    ``self.get_mask`` is binarised (``> 0``) and dilated with a 5x5
    max-pool (stride 1, padding 2). The first output is the transformed
    ``ImageChops.difference`` of the mask image and the clean crop.

    Args:
        raw: Image used (after cropping) as the first ``get_mask`` argument.
        clean: Image used as the target and in the difference image.

    Returns:
        A tuple ``(self.transformer(difference), 1 - dilated_mask,
        self.transformer(clean_crop))``.
    """
    top, left, height, width = RandomResizedCrop.get_params(
        raw, scale=(0.5, 2.0), ratio=(3. / 4., 4. / 3.))
    raw_crop = resized_crop(
        raw, top, left, height, width,
        size=self.img_size, interpolation=Image.BICUBIC)
    clean_crop = resized_crop(
        clean, top, left, height, width,
        self.img_size, interpolation=Image.BICUBIC)

    # The mask comes from the un-augmented crops.
    mask_img = self.get_mask(raw_crop, clean_crop)
    hard_mask = (to_tensor(mask_img) > 0).float()
    dilated = torch.nn.functional.max_pool2d(
        hard_mask, kernel_size=5, stride=1, padding=2)

    difference = ImageChops.difference(mask_img, clean_crop)

    # Keep the original call order: the input is transformed before the
    # target, which matters if the transformer is stateful or random.
    model_input = self.transformer(difference)
    keep = 1 - dilated
    target = self.transformer(clean_crop)
    return model_input, keep, target
def process_images(self, raw, clean):
    """Crop ``raw`` and ``clean`` identically and return both as tensors.

    Only the raw crop goes through ``self.transformer``; the second crop
    is converted to a tensor directly.

    Args:
        raw: Image that is cropped, transformed and converted to a tensor.
        clean: Image cropped with the same window (used here as the mask
            image).

    Returns:
        A tuple ``(to_tensor(transformed_raw_crop), to_tensor(mask_crop))``.
    """
    top, left, height, width = RandomResizedCrop.get_params(
        raw, scale=(0.1, 2), ratio=(3. / 4., 4. / 3.))

    augmented = self.transformer(
        resized_crop(
            raw, top, left, height, width,
            size=self.img_size, interpolation=Image.BICUBIC))
    mask_crop = resized_crop(
        clean, top, left, height, width,
        self.img_size, interpolation=Image.BICUBIC)

    # NOTE(review): to_tensor is applied to the transformer output, so the
    # transformer presumably returns a PIL image or ndarray -- confirm.
    return to_tensor(augmented), to_tensor(mask_crop)
def process_images(self, raw, clean):
    """Return the transformed raw crop together with its mask.

    Both inputs are cropped with the same random window. The mask is
    computed by ``self.get_mask`` on the cropped pair before the raw crop
    is passed through ``self.transformer``.

    Args:
        raw: Image that is cropped, masked against ``clean`` and transformed.
        clean: Reference image cropped with the same window.

    Returns:
        A tuple ``(self.transformer(raw_crop), mask)`` where ``mask`` is
        whatever ``self.get_mask`` returns.
    """
    top, left, height, width = RandomResizedCrop.get_params(
        raw, scale=(0.5, 2.0), ratio=(3. / 4., 4. / 3.))
    raw_crop = resized_crop(
        raw, top, left, height, width,
        size=self.img_size, interpolation=Image.BICUBIC)
    clean_crop = resized_crop(
        clean, top, left, height, width,
        self.img_size, interpolation=Image.BICUBIC)

    # The mask must be taken before any further augmentation of the raw crop.
    mask = self.get_mask(raw_crop, clean_crop)
    return self.transformer(raw_crop), mask
class InitialObservationNumpyJitteringDataset(data.Dataset):
    """Dataset of jittered (observation, environment image) pairs.

    Every (trajectory, timestep) pair is one item. An item holds the
    observation at that timestep and the trajectory's ``'env'`` image,
    both reshaped to 48x48x3 RGB, cropped and colour-jittered with the
    same sampled parameters, then flattened and passed through
    ``normalize_image``.
    """

    def __init__(self, data, info=None):
        """Store the data dict and set up the augmentation objects.

        Args:
            data: Mapping with a uint8 ``'observations'`` array indexed as
                ``[trajectory, timestep, ...]``. If it has no ``'env'``
                entry, one is added in place from the first timestep of
                every trajectory.
            info: Optional extra payload, stored unchanged.
        """
        observations = data['observations']
        assert observations.dtype == np.uint8

        self.size = observations.shape[0]
        self.traj_length = observations.shape[1]
        self.data = data
        self.info = info

        self.jitter = ColorJitter(
            (0.5, 1.5), (0.9, 1.1), (0.9, 1.1), (-0.1, 0.1))
        self.crop = RandomResizedCrop((48, 48), (0.9, 0.9), (1, 1))

        # Fall back to the first frame of each trajectory as its
        # environment image.
        if 'env' not in self.data:
            self.data['env'] = self.data['observations'][:, 0, :]

    def __len__(self):
        """Return the total number of (trajectory, timestep) pairs."""
        return self.size * self.traj_length

    def __getitem__(self, idx):
        """Return the augmented, normalised pair for flat index ``idx``.

        Returns:
            Dict with keys ``'x_t'`` (the observation) and ``'env'`` (the
            environment image), each the squeezed output of
            ``normalize_image`` on the flattened image.
        """
        traj_i = idx // self.traj_length
        trans_i = idx % self.traj_length

        obs_pixels = self.data['observations'][traj_i, trans_i]
        env_pixels = self.data['env'][traj_i]
        x = Image.fromarray(obs_pixels.reshape(48, 48, 3), mode='RGB')
        c = Image.fromarray(env_pixels.reshape(48, 48, 3), mode='RGB')

        # Upsampling gives bad images, so the random-resize parameters
        # are pinned to 1 for now.
        top, left, height, width = self.crop.get_params(c, (1, 1), (1, 1))
        # NOTE(review): called with scalar ranges and used as a callable
        # below, which looks like the older torchvision ColorJitter API --
        # confirm against the installed version.
        jitter = self.jitter.get_params(0.5, 0.1, 0.1, 0.1)

        def augment(img):
            # Same crop window and same jitter for both images of the pair.
            cropped = F.resized_crop(
                img, top, left, height, width, (48, 48), Image.BICUBIC)
            return jitter(cropped)

        x = augment(x)
        c = augment(c)

        return {
            'x_t': normalize_image(np.array(x).flatten()).squeeze(),
            'env': normalize_image(np.array(c).flatten()).squeeze(),
        }
def process_images(self, raw, clean):
    """Build a (corrupted, mask, target) triple using a dilated hole mask.

    The raw and clean images are cropped with the same random window.
    The mask from ``self.get_mask`` is thresholded against the
    module-level ``brightness_difference``, reduced over the channel
    axis with a max, and dilated with a 9x9 max-pool (stride 1,
    padding 4). The clean image, not the raw one, is then corrupted
    with the inverted mask.

    Args:
        raw: Image used to derive the mask. When
            ``self.add_random_masks`` is truthy, ``self.random_mask.draw``
            is applied to its crop first.
        clean: Image used as the target and as the source of the
            corrupted output.

    Returns:
        A tuple ``(corrupted_img, binary_mask, clean_img)``.
        ``binary_mask`` has 3 channels; valid positions are 1 and holes
        are 0. ``corrupted_img`` is ``clean_img * binary_mask``.
    """
    i, j, h, w = RandomResizedCrop.get_params(
        raw, scale=(0.5, 2.0), ratio=(3. / 4., 4. / 3.))
    raw_img = resized_crop(
        raw, i, j, h, w, size=self.img_size, interpolation=Image.BICUBIC)
    if self.add_random_masks:
        raw_img = self.random_mask.draw(raw_img)
    clean_img = resized_crop(
        clean, i, j, h, w, self.img_size, interpolation=Image.BICUBIC)

    # Get the mask before any further image augmentation.
    mask = self.get_mask(raw_img, clean_img)
    mask_t = to_tensor(mask)
    mask_t = (mask_t > brightness_difference).float()
    # torch.max with ``dim`` returns a (values, indices) pair. Only the
    # values form the mask; passing the pair to max_pool2d would fail.
    mask_t, _ = torch.max(mask_t, dim=0, keepdim=True)
    # Stride-1 max-pooling of a binary mask dilates the holes.
    mask_t = torch.nn.functional.max_pool2d(
        mask_t, kernel_size=9, stride=1, padding=4)

    # Corrupt the clean image rather than using the raw one.
    binary_mask = 1 - mask_t  # valid positions are 1; holes are 0
    binary_mask = binary_mask.expand(3, -1, -1)
    clean_img = self.transformer(clean_img)
    corrupted_img = clean_img * binary_mask
    return corrupted_img, binary_mask, clean_img
def process_images(self, clean, mask):
    """Corrupt a clean image using a supplied (optionally augmented) mask.

    The clean image and the mask are cropped with the same random
    window. The mask is optionally extended by ``random_masks``,
    binarised against ``brightness_difference * 255``, dilated with a
    10x10 kernel and inverted before it multiplies the transformed
    clean image.

    Args:
        clean: Image used as the target and as the source of the
            corrupted output.
        mask: Mask image aligned with ``clean``.

    Returns:
        A tuple ``(corrupted_img, binary_mask, clean_img)``.
        ``binary_mask`` has 3 channels; valid positions are 1 and holes
        are 0. ``corrupted_img`` is ``clean_img * binary_mask``.
    """
    top, left, height, width = RandomResizedCrop.get_params(
        clean, scale=(0.5, 2.0), ratio=(3. / 4., 4. / 3.))
    clean_crop = resized_crop(
        clean, top, left, height, width,
        size=self.img_size, interpolation=Image.BICUBIC)
    mask_crop = resized_crop(
        mask, top, left, height, width,
        self.img_size, interpolation=Image.BICUBIC)

    if self.add_random_masks:
        # Pass a copy so the cropped mask itself is not modified.
        mask_crop = random_masks(
            mask_crop.copy(), size=self.img_size[0], offset=10)

    # Binarise to {0, 255} uint8, then grow the holes by dilation.
    is_hole = np.array(mask_crop) > brightness_difference * 255
    mask_arr = np.where(is_hole, np.uint8(255), np.uint8(0))
    mask_arr = cv2.dilate(mask_arr, np.ones((10, 10), np.uint8), iterations=1)
    # Assumes a single-channel mask, so a trailing channel axis is added
    # for to_tensor -- TODO confirm.
    mask_arr = np.expand_dims(mask_arr, -1)

    keep = 1 - to_tensor(mask_arr)  # valid positions are 1; holes are 0
    keep = keep.expand(3, -1, -1)

    target = self.transformer(clean_crop)
    return target * keep, keep, target
def random_resize_crop(augment_targets, scale, ratio, size, threshold,
                       pre_crop_area=None):
    """Apply one shared resized crop to an image and its three target maps.

    Args:
        augment_targets: Sequence of four arrays, in the order
            ``(image, region_score, affinity_score, confidence_mask)``.
        scale: Scale range passed to ``RandomResizedCrop.get_params`` when
            a random crop is drawn.
        ratio: Aspect-ratio range passed to
            ``RandomResizedCrop.get_params`` when a random crop is drawn.
        size: Side length of the square output; every result is resized
            to ``(size, size)``.
        threshold: Probability of drawing the crop with ``scale`` and
            ``ratio``. Otherwise the crop is requested with scale and
            ratio both fixed at 1.0.
        pre_crop_area: Optional explicit crop ``(i, j, h, w)``. When
            given, no random crop is sampled.

    Returns:
        A list ``[image, region_score, affinity_score, confidence_mask]``
        of numpy arrays after cropping and resizing.
    """
    image, region_score, affinity_score, confidence_mask = augment_targets

    image = Image.fromarray(image)
    region_score = Image.fromarray(region_score)
    affinity_score = Image.fromarray(affinity_score)
    confidence_mask = Image.fromarray(confidence_mask)

    # Identity check: ``!= None`` compares elementwise when the crop is
    # passed as a numpy array, which makes the ``if`` raise.
    if pre_crop_area is not None:
        i, j, h, w = pre_crop_area
    elif random.random() < threshold:
        i, j, h, w = RandomResizedCrop.get_params(
            image, scale=scale, ratio=ratio)
    else:
        # No random scale or aspect jitter for this sample.
        i, j, h, w = RandomResizedCrop.get_params(
            image, scale=(1.0, 1.0), ratio=(1.0, 1.0))

    out_size = (size, size)
    image = resized_crop(
        image, i, j, h, w, size=out_size,
        interpolation=InterpolationMode.BICUBIC)
    region_score = resized_crop(
        region_score, i, j, h, w, out_size,
        interpolation=InterpolationMode.BICUBIC)
    affinity_score = resized_crop(
        affinity_score, i, j, h, w, out_size,
        interpolation=InterpolationMode.BICUBIC)
    # The confidence mask is resampled with NEAREST, presumably so that
    # its values are not blended by interpolation -- confirm with callers.
    confidence_mask = resized_crop(
        confidence_mask, i, j, h, w, out_size,
        interpolation=InterpolationMode.NEAREST)

    return [
        np.array(image),
        np.array(region_score),
        np.array(affinity_score),
        np.array(confidence_mask),
    ]