def apply_transform(  # type: ignore
    self, input: torch.Tensor, label: torch.Tensor, params: Dict[str, torch.Tensor]  # type: ignore
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Paste rectangular regions from permuted batch samples into ``input``.

    For every mix round, the samples selected by ``params['mix_pairs']`` are
    copied into the output wherever the box in ``params['crop_src']`` is set,
    and a label record is built from the original label, the permuted label
    and the box-area ratio.

    Args:
        input: image batch; dimensions 1, 2 and 3 are read as channels,
            height and width.
        label: labels, indexed along dim 0 with the same permutation as
            ``input`` (presumably one label per sample - confirm with caller).
        params: must hold ``'mix_pairs'`` (validated as
            ``[num_mixes, batch_size]``) and ``'crop_src'`` (validated as
            ``[num_mixes, batch_size, 4, 2]``).

    Returns:
        The mixed images and the per-round label records stacked along dim 0.
    """
    height, width = input.size(2), input.size(3)

    mix_pairs = params['mix_pairs']
    num_mixes = mix_pairs.size(0)
    batch_size = mix_pairs.size(1)
    _shape_validation(mix_pairs, [num_mixes, batch_size], 'mix_pairs')

    crop_src = params['crop_src']
    _shape_validation(crop_src, [num_mixes, batch_size, 4, 2], 'crop_src')

    mixed = input.clone()
    label_records = []
    for permutation, boxes in zip(mix_pairs, crop_src):
        # Sources are always taken from the untouched ``input``, never from
        # the partially mixed result of a previous round.
        partner_inputs = input.index_select(dim=0, index=permutation.to(input.device))
        partner_labels = label.index_select(dim=0, index=permutation.to(label.device))

        # Only the product of the two extents is used, i.e. the box area
        # relative to the full image area (width_beta * height_beta).
        w, h = infer_bbox_shape(boxes)
        area_ratio = w.to(input.dtype) * h.to(input.dtype) / (width * height)

        # Broadcast the per-sample box mask over the channel dimension so it
        # matches the shape of ``input``.
        region = bbox_to_mask(boxes, width, height).bool()
        region = region.unsqueeze(dim=1).repeat(1, input.size(1), 1, 1)
        mixed[region] = partner_inputs[region]

        record = [
            label.to(input.dtype),
            partner_labels.to(input.dtype),
            area_ratio.to(label.device),
        ]
        label_records.append(torch.stack(record, dim=1))

    return mixed, torch.stack(label_records, dim=0)
def test_inverse_and_forward_return_transform(self, random_apply, device, dtype):
    """Inverse must fail before any forward pass; forward must keep shapes.

    Every augmentation in the pipeline is created with
    ``return_transform=True`` where the argument is passed, so the image
    output is indexed as ``out[0][0]``.
    """
    tensor_opts = dict(device=device, dtype=dtype)
    image = torch.randn(1, 3, 1000, 500, **tensor_opts)
    boxes = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]], **tensor_opts)
    points = torch.tensor([[[465, 115], [545, 116]]], **tensor_opts)
    mask_box = torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]], **tensor_opts)
    # NOTE(review): the image is (1, 3, 1000, 500) while (1000, 500) is passed
    # here; if bbox_to_mask takes (boxes, width, height) the mask is transposed
    # relative to the image - confirm this is intended.
    seg_mask = bbox_to_mask(mask_box, 1000, 500)[:, None].float()

    nested = K.ImageSequential(
        K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
        K.RandomAffine(360, p=1.0, return_transform=True),
    )
    pipeline = K.AugmentationSequential(
        nested,
        K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0, return_transform=True),
        K.RandomAffine(360, p=1.0, return_transform=True),
        data_keys=["input", "mask", "bbox", "keypoints"],
        random_apply=random_apply,
    )

    # No parameters available for inversing.
    with pytest.raises(Exception):
        pipeline.inverse(image, seg_mask, boxes, points)

    out = pipeline(image, seg_mask, boxes, points)
    assert out[0][0].shape == image.shape
    for position, reference in enumerate((seg_mask, boxes, points), start=1):
        assert out[position].shape == reference.shape

    reproducibility_test((image, seg_mask, boxes, points), pipeline)
def test_individual_forward_and_inverse(self, device, dtype):
    """Run forward and inverse one data key at a time and compare shapes."""
    tensor_opts = dict(device=device, dtype=dtype)
    image = torch.randn(1, 3, 1000, 500, **tensor_opts)
    boxes = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]], **tensor_opts)
    points = torch.tensor([[[465, 115], [545, 116]]], **tensor_opts)
    mask_box = torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]], **tensor_opts)
    # NOTE(review): the image is (1, 3, 1000, 500) while (1000, 500) is passed
    # here; if bbox_to_mask takes (boxes, width, height) the mask is transposed
    # relative to the image - confirm this is intended.
    seg_mask = bbox_to_mask(mask_box, 1000, 500)[:, None].float()

    # All four data keys at once, without returning the transform.
    multi_key = K.AugmentationSequential(
        K.RandomAffine(360, p=1.0, return_transform=False),
        data_keys=['input', 'mask', 'bbox', 'keypoints'],
    )
    reproducibility_test((image, seg_mask, boxes, points), multi_key)

    # With return_transform=True the output is indexed to reach the image.
    with_transform = K.AugmentationSequential(K.RandomAffine(360, p=1.0, return_transform=True))
    assert with_transform(image, data_keys=['input'])[0].shape == image.shape

    # Without return_transform the output itself exposes ``.shape``.
    aug = K.AugmentationSequential(K.RandomAffine(360, p=1.0, return_transform=False))
    assert aug(image, data_keys=['input']).shape == image.shape
    # Re-apply the parameters stored by the previous call to the mask.
    assert aug(seg_mask, data_keys=['mask'], params=aug._params).shape == seg_mask.shape

    inverse_cases = (
        (image, 'input'),
        (boxes, 'bbox'),
        (points, 'keypoints'),
        (seg_mask, 'mask'),
    )
    for sample, key in inverse_cases:
        assert aug.inverse(sample, data_keys=[key]).shape == sample.shape
def test_forward_and_inverse(self, random_apply, return_transform, device, dtype):
    """Forward then inverse over four data keys must preserve every shape.

    ``random_apply`` and ``return_transform`` are forwarded unchanged to
    ``K.AugmentationSequential``.
    """
    tensor_opts = dict(device=device, dtype=dtype)
    image = torch.randn(1, 3, 1000, 500, **tensor_opts)
    boxes = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]], **tensor_opts)
    points = torch.tensor([[[465, 115], [545, 116]]], **tensor_opts)
    mask_box = torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]], **tensor_opts)
    # NOTE(review): the image is (1, 3, 1000, 500) while (1000, 500) is passed
    # here; if bbox_to_mask takes (boxes, width, height) the mask is transposed
    # relative to the image - confirm this is intended.
    seg_mask = bbox_to_mask(mask_box, 1000, 500)[:, None].float()

    pipeline = K.AugmentationSequential(
        K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
        K.RandomAffine(360, p=1.0),
        data_keys=["input", "mask", "bbox", "keypoints"],
        random_apply=random_apply,
        return_transform=return_transform,
    )
    references = (seg_mask, boxes, points)

    out = pipeline(image, seg_mask, boxes, points)
    # When a transform is returned, the image entry is indexed once more.
    image_is_nested = return_transform and isinstance(out, (tuple, list))
    augmented_image = out[0][0] if image_is_nested else out[0]
    assert augmented_image.shape == image.shape
    for position, reference in enumerate(references, start=1):
        assert out[position].shape == reference.shape

    reproducibility_test((image, seg_mask, boxes, points), pipeline)

    restored = pipeline.inverse(*out)
    assert restored[0].shape == image.shape
    for position, reference in enumerate(references, start=1):
        assert restored[position].shape == reference.shape
def test_individual_forward_and_inverse(self, device, dtype):
    """Check shapes for a nested pipeline ending in a crop, then per-key calls.

    The first part runs a pipeline containing an ``ImageSequential``, a nested
    ``AugmentationSequential``, an affine and a ``RandomCrop``; the second part
    drives a single-affine pipeline one data key at a time.
    """
    tensor_opts = dict(device=device, dtype=dtype)
    image = torch.randn(1, 3, 1000, 500, **tensor_opts)
    boxes = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]], **tensor_opts)
    points = torch.tensor([[[465, 115], [545, 116]]], **tensor_opts)
    mask_box = torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]], **tensor_opts)
    seg_mask = bbox_to_mask(mask_box, 500, 1000)[:, None].float()

    crop_size = (200, 200)
    image_branch = K.ImageSequential(
        K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        K.RandomAffine(360, p=1.0),
    )
    nested_branch = K.AugmentationSequential(
        K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        K.RandomAffine(360, p=1.0),
    )
    pipeline = K.AugmentationSequential(
        image_branch,
        nested_branch,
        K.RandomAffine(360, p=1.0),
        K.RandomCrop(crop_size, padding=1, cropping_mode='resample', fill=0),
        data_keys=['input', 'mask', 'bbox', 'keypoints'],
    )
    reproducibility_test((image, seg_mask, boxes, points), pipeline)

    out = pipeline(image, seg_mask, boxes, points)
    # Image and mask keep batch/channel dims but take the crop's spatial size.
    assert out[0].shape == (*image.shape[:2], *crop_size)
    assert out[1].shape == (*seg_mask.shape[:2], *crop_size)
    assert out[2].shape == boxes.shape
    assert out[3].shape == points.shape

    restored = pipeline.inverse(*out)
    for position, reference in enumerate((image, seg_mask, boxes, points)):
        assert restored[position].shape == reference.shape

    # The single-affine pipeline is built and exercised twice in a row; both
    # passes are kept so the sequence of calls stays exactly as before.
    aug = K.AugmentationSequential(K.RandomAffine(360, p=1.0))
    assert aug(image, data_keys=['input']).shape == image.shape

    aug = K.AugmentationSequential(K.RandomAffine(360, p=1.0))
    assert aug(image, data_keys=['input']).shape == image.shape
    # Re-apply the parameters stored by the previous call to the mask.
    assert aug(seg_mask, data_keys=['mask'], params=aug._params).shape == seg_mask.shape

    inverse_cases = (
        (image, 'input'),
        (boxes, 'bbox'),
        (points, 'keypoints'),
        (seg_mask, 'mask'),
    )
    for sample, key in inverse_cases:
        assert aug.inverse(sample, data_keys=[key]).shape == sample.shape
def apply_transform(self, input: Tensor, params: Dict[str, Tensor], transform: Optional[Tensor] = None) -> Tensor:
    """Overwrite a rectangle in each sample of ``input`` with a fill value.

    Args:
        input: batch unpacked as four dimensions (batch, channels, height,
            width).
        params: must provide ``"values"`` (fill values), and ``"xs"``,
            ``"ys"``, ``"widths"``, ``"heights"``, which are handed to
            ``bbox_generator`` to build the rectangles.
        transform: accepted for interface compatibility; not used here.

    Returns:
        A tensor equal to ``input`` except inside the rectangles, where the
        fill values are used instead.
    """
    _, channels, height, width = input.size()

    # Append three trailing singleton dims, then tile to the per-sample shape
    # of ``input`` and match its dtype/device.
    fill = params["values"][..., None, None, None]
    fill = fill.repeat(1, *input.shape[1:]).to(input)

    boxes = bbox_generator(params["xs"], params["ys"], params["widths"], params["heights"])
    erase_mask = bbox_to_mask(boxes, width, height)  # Returns B, H, W
    # Transform to B, c, H, W
    erase_mask = erase_mask.unsqueeze(1).repeat(1, channels, 1, 1).to(input)

    return torch.where(erase_mask == 1.0, fill, input)
def test_forward_and_inverse_return_transform(self, random_apply, device, dtype):
    """Forward then inverse through nested containers must preserve shapes.

    The pipeline mixes an ``ImageSequential``, a nested
    ``AugmentationSequential`` and two plain augmentations; ``random_apply``
    is forwarded unchanged to the outer container.
    """
    tensor_opts = dict(device=device, dtype=dtype)
    image = torch.randn(1, 3, 1000, 500, **tensor_opts)
    boxes = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]], **tensor_opts)
    points = torch.tensor([[[465, 115], [545, 116]]], **tensor_opts)
    mask_box = torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]], **tensor_opts)
    # NOTE(review): the image is (1, 3, 1000, 500) while (1000, 500) is passed
    # here; if bbox_to_mask takes (boxes, width, height) the mask is transposed
    # relative to the image - confirm this is intended.
    seg_mask = bbox_to_mask(mask_box, 1000, 500)[:, None].float()

    image_branch = K.ImageSequential(
        K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        K.RandomAffine(360, p=1.0),
    )
    nested_branch = K.AugmentationSequential(
        K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        K.RandomAffine(360, p=1.0),
    )
    pipeline = K.AugmentationSequential(
        image_branch,
        nested_branch,
        K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        K.RandomAffine(360, p=1.0),
        data_keys=["input", "mask", "bbox", "keypoints"],
        random_apply=random_apply,
    )
    references = (image, seg_mask, boxes, points)

    out = pipeline(image, seg_mask, boxes, points)
    for position, reference in enumerate(references):
        assert out[position].shape == reference.shape

    reproducibility_test((image, seg_mask, boxes, points), pipeline)

    # TODO(jian): we sometimes throw the following error
    # AttributeError: 'tuple' object has no attribute 'shape'
    restored = pipeline.inverse(*out)
    for position, reference in enumerate(references):
        assert restored[position].shape == reference.shape