def get_train_augmentation_transforms(
    p_flip_vertical=0.5,
    p_flip_horizontal=0.5,
    max_rotation=10.0,
    max_zoom=1.1,
    max_warp=0.2,
    p_affine=0.75,
    max_lighting=0.2,
    p_lighting=0.75,
):
    """
    Build a set of pytorch image Transforms to use during training:

    p_flip_vertical: probability of a vertical flip
    p_flip_horizontal: probability of a horizontal flip
    max_rotation: maximum rotation angle in degrees
    max_zoom: maximum zoom level
    max_warp: perspective warping scale (from 0.0 to 1.0)
    p_affine: probability of rotation, zoom and perspective warping
    max_lighting: maximum scaling of brightness and contrast
    p_lighting: probability of applying the brightness/contrast jitter
    """
    return [
        kaug.RandomVerticalFlip(p=p_flip_vertical),
        kaug.RandomHorizontalFlip(p=p_flip_horizontal),
        # Affine and perspective warps share the same application probability.
        kaug.RandomAffine(p=p_affine, degrees=max_rotation, scale=(1.0, max_zoom)),
        kaug.RandomPerspective(p=p_affine, distortion_scale=max_warp),
        kaug.ColorJitter(p=p_lighting, brightness=max_lighting, contrast=max_lighting),
        # TODO: the kornia transforms work on batches and so by default they
        # add a batch dimension. Until I work out how to apply transform by
        # batches (and only while training) I will just keep this here to
        # remove the batch dimension again
        GetItemTransform(),
    ]
def test_inverse_and_forward_return_transform(self, device, dtype):
    """Forward and inverse passes of AugmentationSequential keep the shapes of
    input / mask / bbox / keypoints; forward returns (image, transform) pairs
    for the input because return_transform=True."""
    inp = torch.randn(1, 3, 1000, 500, device=device, dtype=dtype)
    bbox = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]], device=device, dtype=dtype)
    keypoints = torch.tensor([[[465, 115], [545, 116]]], device=device, dtype=dtype)
    # Build a single-channel float mask of the same spatial size from box corners.
    mask = bbox_to_mask(
        torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]], device=device, dtype=dtype), 1000, 500)[:, None].float()
    aug = K.AugmentationSequential(
        K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0, return_transform=True),
        K.RandomAffine(360, p=1.0, return_transform=True),
        data_keys=["input", "mask", "bbox", "keypoints"],
    )
    # Inverse before any forward: outputs are plain tensors, shapes preserved.
    out_inv = aug.inverse(inp, mask, bbox, keypoints)
    assert out_inv[0].shape == inp.shape
    assert out_inv[1].shape == mask.shape
    assert out_inv[2].shape == bbox.shape
    assert out_inv[3].shape == keypoints.shape
    out = aug(inp, mask, bbox, keypoints)
    # out[0] is an (image, transform) tuple because of return_transform=True.
    assert out[0][0].shape == inp.shape
    assert out[1].shape == mask.shape
    assert out[2].shape == bbox.shape
    assert out[3].shape == keypoints.shape
def test_inverse_and_forward_return_transform(self, random_apply, device, dtype):
    """With random_apply set, calling inverse() before a forward pass must fail
    (no stored parameters to invert); the forward pass itself keeps shapes."""
    inp = torch.randn(1, 3, 1000, 500, device=device, dtype=dtype)
    bbox = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]], device=device, dtype=dtype)
    keypoints = torch.tensor([[[465, 115], [545, 116]]], device=device, dtype=dtype)
    mask = bbox_to_mask(
        torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]], device=device, dtype=dtype), 1000, 500)[:, None].float()
    aug = K.AugmentationSequential(
        K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0, return_transform=True),
        K.RandomAffine(360, p=1.0, return_transform=True),
        data_keys=["input", "mask", "bbox", "keypoints"],
        random_apply=random_apply,
    )
    with pytest.raises(
            Exception):  # No parameters available for inversing.
        aug.inverse(inp, mask, bbox, keypoints)
    out = aug(inp, mask, bbox, keypoints)
    # out[0] is an (image, transform) tuple because of return_transform=True.
    assert out[0][0].shape == inp.shape
    assert out[1].shape == mask.shape
    assert out[2].shape == bbox.shape
    assert out[3].shape == keypoints.shape
    reproducibility_test((inp, mask, bbox, keypoints), aug)
def test_forward_and_inverse(self, random_apply, device, dtype):
    """Forward then inverse must round-trip the shapes of all data keys for
    every combination of random_apply and return_transform."""
    inp = torch.randn(1, 3, 1000, 500, device=device, dtype=dtype)
    bbox = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]], device=device, dtype=dtype)
    keypoints = torch.tensor([[[465, 115], [545, 116]]], device=device, dtype=dtype)
    mask = bbox_to_mask(
        torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]], device=device, dtype=dtype), 1000, 500)[:, None].float()
    aug = K.AugmentationSequential(
        K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
        K.RandomAffine(360, p=1.0),
        data_keys=["input", "mask", "bbox", "keypoints"],
        random_apply=random_apply,
        return_transform=return_transform,
    )
    out = aug(inp, mask, bbox, keypoints)
    if return_transform and isinstance(out, (tuple, list)):
        # (image, transform) tuple is only produced when transforms are returned.
        assert out[0][0].shape == inp.shape
    else:
        assert out[0].shape == inp.shape
    assert out[1].shape == mask.shape
    assert out[2].shape == bbox.shape
    assert out[3].shape == keypoints.shape
    reproducibility_test((inp, mask, bbox, keypoints), aug)
    out_inv = aug.inverse(*out)
    assert out_inv[0].shape == inp.shape
    assert out_inv[1].shape == mask.shape
    assert out_inv[2].shape == bbox.shape
    assert out_inv[3].shape == keypoints.shape
def __init__(self, mean, std, scale=(0.9, 1.1), max_degrees=0) -> None:
    """Pair of transforms: a nearest-neighbour random affine (rotation up to
    ``max_degrees``, zoom within ``scale``) and a mean/std normalisation."""
    super(Transform, self).__init__()
    self.max_degrees = max_degrees
    self.norm = k.Normalize(mean, std)
    self.aff = k.RandomAffine(
        max_degrees,
        resample=k.Resample.NEAREST,
        scale=scale,
    )
def test_mixup(self, inp, random_apply, same_on_batch, device, dtype):
    """A pipeline ending in RandomMixUp keeps the per-image (C, H, W) shape;
    mixup may additionally return labels, which are discarded here."""
    inp = torch.as_tensor(inp, device=device, dtype=dtype)
    aug = K.AugmentationSequential(
        K.ImageSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0), K.RandomAffine(360, p=1.0)),
        K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        K.RandomAffine(360, p=1.0),
        K.RandomMixUp(p=1.0),
        data_keys=["input"],
        random_apply=random_apply,
        same_on_batch=same_on_batch,
    )
    out = aug(inp)
    if aug.return_label:
        # Mixup produces (images, labels); only the images are checked.
        out, _ = out
    assert out.shape[-3:] == inp.shape[-3:]
    reproducibility_test(inp, aug)
def test_construction(self, same_on_batch, return_transform, keepdim):
    """Smoke test: ImageSequential must construct for every flag combination."""
    children = [
        K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
        K.RandomAffine(360, p=1.0),
    ]
    K.ImageSequential(
        *children,
        same_on_batch=same_on_batch,
        return_transform=return_transform,
        keepdim=keepdim,
    )
def test_jit(self, device, dtype):
    """TorchScript-compiled AugmentationSequential must match eager output.

    same_on_batch=True makes the random parameters identical across the batch
    so the comparison is deterministic for a fixed RNG state.
    """
    B, C, H, W = 2, 3, 4, 4
    img = torch.ones(B, C, H, W, device=device, dtype=dtype)
    op = K.AugmentationSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0), K.RandomAffine(360, p=1.0), same_on_batch=True)
    op_jit = torch.jit.script(op)
    assert_close(op(img), op_jit(img))
def test_forward(self, random_apply, device, dtype):
    """ImageSequential mixing augmentations, a plain filter, nested sequentials
    and mixup keeps the input shape; inverse() on a fresh input must not raise."""
    inp = torch.randn(1, 3, 30, 30, device=device, dtype=dtype)
    aug = K.ImageSequential(
        K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        # Non-random op mixed into the sequence on purpose.
        kornia.filters.MedianBlur((3, 3)),
        K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        K.ImageSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0)),
        K.ImageSequential(K.RandomAffine(360, p=1.0)),
        K.RandomAffine(360, p=1.0),
        K.RandomMixUp(p=1.0),
        random_apply=random_apply,
    )
    out = aug(inp)
    if aug.return_label:
        out, _ = out
    assert out.shape == inp.shape
    aug.inverse(inp)
    reproducibility_test(inp, aug)
def __init__(self, cutn):
    """Cutout helper: keeps `cutn` and a random augmentation pipeline plus a
    noise factor used downstream."""
    super().__init__()
    self.cutn = cutn
    self.noise_fac = 0.1
    # NOTE: K.RandomSolarize(0.01, 0.01, p=0.7) was previously in this
    # pipeline but is disabled.
    pipeline = [
        K.RandomHorizontalFlip(p=0.5),
        K.ColorJitter(hue=0.01, saturation=0.01, p=0.7),
        K.RandomSharpness(0.3, p=0.4),
        K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
        K.RandomPerspective(0.2, p=0.4),
    ]
    self.augs = nn.Sequential(*pipeline)
def test_transform_kornia():
    """Run few epochs to check ``BatchTransformCallback`` callback.

    Trains a tiny linear classifier on MNIST with a kornia RandomAffine
    applied to each batch at ``on_batch_start``, then checks that inference
    yields 10-way logits.
    """
    model = torch.nn.Linear(28 * 28, 10)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
    # NOTE: both loaders use train=False (the small test split) — presumably
    # deliberate to keep this smoke test fast.
    loaders = {
        "train": DataLoader(
            MnistDataset(
                MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())),
            batch_size=32,
        ),
        "valid": DataLoader(
            MnistDataset(
                MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())),
            batch_size=32,
        ),
    }
    # Fixed typo: was `transrorms`.
    transforms = [
        augmentation.RandomAffine(degrees=(-15, 20), scale=(0.75, 1.25)),
    ]
    runner = CustomRunner()
    # model training
    runner.train(
        model=model,
        optimizer=optimizer,
        loaders=loaders,
        logdir="./logs",
        num_epochs=5,
        verbose=False,
        load_best_on_end=True,
        check=True,
        callbacks=[
            BatchTransformCallback(transform=transforms,
                                   scope="on_batch_start",
                                   input_key="features")
        ],
    )
    # model inference
    for prediction in runner.predict_loader(loader=loaders["train"]):
        assert prediction.detach().cpu().numpy().shape[-1] == 10
def test_forward(self, return_transform, random_apply, device, dtype):
    """ImageSequential with a mid-pipeline return_transform op keeps shapes,
    whether or not the container itself returns transforms."""
    inp = torch.randn(1, 3, 30, 30, device=device, dtype=dtype)
    aug = K.ImageSequential(
        K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
        # Non-random op mixed into the sequence on purpose.
        kornia.filters.MedianBlur((3, 3)),
        K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0, return_transform=True),
        K.ImageSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)),
        K.ImageSequential(K.RandomAffine(360, p=1.0)),
        K.RandomAffine(360, p=1.0),
        K.RandomMixUp(p=1.0),
        return_transform=return_transform,
        random_apply=random_apply,
    )
    out = aug(inp)
    if aug.return_label:
        out, _ = out
    if isinstance(out, (tuple, )):
        # (image, transform) output shape check.
        assert out[0].shape == inp.shape
    else:
        assert out.shape == inp.shape
    aug.inverse(inp)
    reproducibility_test(inp, aug)
def test_mixup(self, inp, return_transform, random_apply, same_on_batch, device, dtype):
    """Mixup pipeline with nested ImageSequential keeps the (C, H, W) shape
    under all return_transform / random_apply / same_on_batch combinations."""
    inp = torch.as_tensor(inp, device=device, dtype=dtype)
    aug = K.AugmentationSequential(
        K.ImageSequential(
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
            K.RandomAffine(360, p=1.0, return_transform=True)),
        K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
        K.RandomAffine(360, p=1.0),
        K.RandomMixUp(p=1.0),
        data_keys=["input"],
        random_apply=random_apply,
        return_transform=return_transform,
        same_on_batch=same_on_batch,
    )
    out = aug(inp)
    if aug.return_label:
        # Mixup produces (images, labels); only the images are checked.
        out, _ = out
    if return_transform and isinstance(out, (tuple, list)):
        out = out[0]
    assert out.shape[-3:] == inp.shape[-3:]
    reproducibility_test(inp, aug)
def test_individual_forward_and_inverse(self, device, dtype):
    """Each data key can be transformed / inverted individually, with data_keys
    given per call instead of at construction time."""
    inp = torch.randn(1, 3, 1000, 500, device=device, dtype=dtype)
    bbox = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]], device=device, dtype=dtype)
    keypoints = torch.tensor([[[465, 115], [545, 116]]], device=device, dtype=dtype)
    mask = bbox_to_mask(
        torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]], device=device, dtype=dtype), 1000, 500)[:, None].float()
    # Nested containers with mixed return_transform flags.
    aug = K.AugmentationSequential(
        K.ImageSequential(
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
            K.RandomAffine(360, p=1.0, return_transform=True)),
        K.AugmentationSequential(
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
            K.RandomAffine(360, p=1.0, return_transform=True)),
        K.RandomAffine(360, p=1.0, return_transform=False),
        data_keys=['input', 'mask', 'bbox', 'keypoints'],
    )
    reproducibility_test((inp, mask, bbox, keypoints), aug)
    aug = K.AugmentationSequential(
        K.RandomAffine(360, p=1.0, return_transform=True))
    # With return_transform=True the output is (image, transform).
    assert aug(inp, data_keys=['input'])[0].shape == inp.shape
    aug = K.AugmentationSequential(
        K.RandomAffine(360, p=1.0, return_transform=False))
    assert aug(inp, data_keys=['input']).shape == inp.shape
    # Re-apply the stored parameters from the previous forward to the mask.
    assert aug(mask, data_keys=['mask'], params=aug._params).shape == mask.shape
    assert aug.inverse(inp, data_keys=['input']).shape == inp.shape
    assert aug.inverse(bbox, data_keys=['bbox']).shape == bbox.shape
    assert aug.inverse(keypoints, data_keys=['keypoints']).shape == keypoints.shape
    assert aug.inverse(mask, data_keys=['mask']).shape == mask.shape
def test_forward_and_inverse_return_transform(self, random_apply, device, dtype):
    """Nested ImageSequential/AugmentationSequential containers round-trip all
    data-key shapes through forward and inverse."""
    inp = torch.randn(1, 3, 1000, 500, device=device, dtype=dtype)
    bbox = torch.tensor([[[355, 10], [660, 10], [660, 250], [355, 250]]], device=device, dtype=dtype)
    keypoints = torch.tensor([[[465, 115], [545, 116]]], device=device, dtype=dtype)
    mask = bbox_to_mask(
        torch.tensor([[[155, 0], [900, 0], [900, 400], [155, 400]]], device=device, dtype=dtype), 1000, 500)[:, None].float()
    aug = K.AugmentationSequential(
        K.ImageSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0), K.RandomAffine(360, p=1.0)),
        K.AugmentationSequential(K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0), K.RandomAffine(360, p=1.0)),
        K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        K.RandomAffine(360, p=1.0),
        data_keys=["input", "mask", "bbox", "keypoints"],
        random_apply=random_apply,
    )
    out = aug(inp, mask, bbox, keypoints)
    assert out[0].shape == inp.shape
    assert out[1].shape == mask.shape
    assert out[2].shape == bbox.shape
    assert out[3].shape == keypoints.shape
    reproducibility_test((inp, mask, bbox, keypoints), aug)
    # TODO(jian): we sometimes throw the following error
    # AttributeError: 'tuple' object has no attribute 'shape'
    out_inv = aug.inverse(*out)
    assert out_inv[0].shape == inp.shape
    assert out_inv[1].shape == mask.shape
    assert out_inv[2].shape == bbox.shape
    assert out_inv[3].shape == keypoints.shape
def test_forward(self, return_transform, device, dtype):
    """ImageSequential forward keeps the input shape whether or not the
    output carries a transform matrix."""
    inp = torch.randn(1, 3, 30, 30, device=device, dtype=dtype)
    aug = K.ImageSequential(
        K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
        # Non-random op mixed into the sequence on purpose.
        kornia.filters.MedianBlur((3, 3)),
        K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0, return_transform=True),
        K.RandomAffine(360, p=1.0),
        return_transform=return_transform,
    )
    out = aug(inp)
    if isinstance(out, (tuple, )):
        # (image, transform) output shape check.
        assert out[0].shape == inp.shape
    else:
        assert out.shape == inp.shape
def test_construction(self, same_on_batch, keepdim, random_apply):
    """The forward sequence may contain at most one mix augmentation, and
    flag changes on the container must propagate to all children."""
    aug = K.ImageSequential(
        K.ColorJiggle(0.1, 0.1, 0.1, 0.1, p=1.0),
        K.RandomAffine(360, p=1.0),
        K.RandomMixUp(p=1.0),
        same_on_batch=same_on_batch,
        keepdim=keepdim,
        random_apply=random_apply,
    )
    # Count mix augmentations actually scheduled for this forward pass.
    c = 0
    for a in aug.get_forward_sequence():
        if isinstance(a, (MixAugmentationBase, )):
            c += 1
    assert c < 2
    # Setting flags on the container must be forwarded to every child.
    aug.same_on_batch = True
    aug.keepdim = True
    for m in aug.children():
        assert m.same_on_batch is True, m.same_on_batch
        assert m.keepdim is True, m.keepdim
def __init__(self, utes, mask, ct, length, opt):
    """Training dataset wrapping UTE volumes, a mask and CT labels.

    utes: input volumes; first dimension is the number of volumes
    mask: mask volume (stored as-is)
    ct: CT volume used as the label
    length: dataset length reported to the loader
    opt: options object; only `trainBatchSize` is read here
    """
    super(TrainDataset, self).__init__()
    self.utes = utes
    self.label = ct
    self.mask = mask
    self.length = length
    self.num_vols = utes.shape[0]
    self.batch_size = opt.trainBatchSize
    # Spatial augmentations; same_on_batch=True so all slices in a batch get
    # the identical transform (keeps inputs and labels aligned).
    self.spatial = nn.Sequential(
        ka.RandomAffine(45,
                        translate=(0.1, 0.1),
                        scale=(0.85, 1.15),
                        shear=(0.1, 0.1),
                        same_on_batch=True),
        ka.RandomVerticalFlip(same_on_batch=True),
        ka.RandomHorizontalFlip(same_on_batch=True))
    # Slicing axis index — presumably the through-plane dimension; TODO confirm.
    self.dim = 2
    self.counter = 0
def __init__(self,
             brightness=(0.75, 1.25),
             contrast=(0.75, 1.25),
             saturation=(0., 2.),
             translate=(0.125, 0.125),
             normalized=True,
             mean=0.5,
             std=0.5,
             device=None):
    """Differentiable-augmentation toolbox (color jitter / translation / cutout).

    brightness, contrast, saturation: ColorJitter ranges
    translate: max translation fraction for RandomAffine
    normalized: when True, build Normalize/Denormalize from mean and std
    mean, std: scalars, or per-channel sequences (then `device` is required)
    device: torch.device used to place per-channel mean/std tensors

    Raises:
        ValueError: if per-channel mean/std are given without a device.
    """
    if normalized:
        if isinstance(mean, (tuple, list)) and isinstance(std, (tuple, list)):
            if not device:
                # Fix: raise ValueError instead of a bare Exception — more
                # precise, and still caught by any `except Exception` caller.
                raise ValueError(
                    'Please specify a torch.device() object when using mean and std for each channels'
                )
            mean = torch.Tensor(mean).to(device)
            std = torch.Tensor(std).to(device)
        self.normalize = aug.Normalize(mean, std)
        self.denormalize = aug.Denormalize(mean, std)
    else:
        self.normalize, self.denormalize = None, None
    color_jitter = aug.ColorJitter(
        brightness=brightness, contrast=contrast, saturation=saturation,
        p=1.)  # rand_brightness, rand_contrast, rand_saturation
    affine = aug.RandomAffine(degrees=0,
                              translate=translate,
                              padding_mode=SamplePadding.BORDER,
                              p=1.)  # rand_translate
    cutout = aug.RandomErasing(value=0.5, p=1.)  # rand_cutout
    self.augmentations = {
        'color': color_jitter,
        'translation': affine,
        'cutout': cutout,
    }
def __init__(self, cut_size, cutn, cut_pow=1.):
    """Cutout module: produces `cutn` crops of side `cut_size`.

    cut_size: output crop side length (used by both pooling layers)
    cutn: number of cutouts per image
    cut_pow: exponent biasing the random crop-size distribution
    """
    super().__init__()
    self.cut_size = cut_size
    self.cutn = cutn
    self.cut_pow = cut_pow
    self.augs = nn.Sequential(
        # Disabled alternatives kept for reference:
        # K.RandomHorizontalFlip(p=0.5),
        # K.RandomVerticalFlip(p=0.5),
        # K.RandomSolarize(0.01, 0.01, p=0.7),
        # K.RandomSharpness(0.3, p=0.4),
        # K.RandomResizedCrop(
        #     size=(self.cut_size, self.cut_size),
        #     scale=(0.1, 1), ratio=(0.75, 1.333),
        #     cropping_mode="resample", p=0.5
        # ),
        # K.RandomCrop(
        #     size=(self.cut_size, self.cut_size), p=0.5
        # ),
        K.RandomAffine(
            degrees=15,
            translate=0.1,
            p=0.7,
            padding_mode="border"
        ),
        K.RandomPerspective(0.7, p=0.7),
        K.ColorJitter(hue=0.1, saturation=0.1, p=0.7),
        K.RandomErasing(
            (.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7
        ),
    )
    self.noise_fac = 0.1
    # Average and max pooling both target the final crop size.
    self.av_pool = nn.AdaptiveAvgPool2d(
        (self.cut_size, self.cut_size)
    )
    self.max_pool = nn.AdaptiveMaxPool2d(
        (self.cut_size, self.cut_size)
    )
mode="validation", download=True) transform = { "per_sample_transform": nn.Sequential( ApplyToKeys( DataKeys.INPUT, nn.Sequential( torchvision.transforms.ToTensor(), Kg.Resize((196, 196)), # SPATIAL Ka.RandomHorizontalFlip(p=0.25), Ka.RandomRotation(degrees=90.0, p=0.25), Ka.RandomAffine(degrees=1 * 5.0, shear=1 / 5, translate=1 / 20, p=0.25), Ka.RandomPerspective(distortion_scale=1 / 25, p=0.25), # PIXEL-LEVEL Ka.ColorJitter(brightness=1 / 30, p=0.25), # brightness Ka.ColorJitter(saturation=1 / 30, p=0.25), # saturation Ka.ColorJitter(contrast=1 / 30, p=0.25), # contrast Ka.ColorJitter(hue=1 / 30, p=0.25), # hue Ka.RandomMotionBlur(kernel_size=2 * (4 // 3) + 1, angle=1, direction=1.0, p=0.25), Ka.RandomErasing(scale=(1 / 100, 1 / 50), ratio=(1 / 20, 1), p=0.25), ),
class TestVideoSequential:
    """Tests for K.VideoSequential over 5-D video batches (BCTHW / BTCHW)."""

    @pytest.mark.parametrize('shape', [(3, 4), (2, 3, 4), (2, 3, 5, 6), (2, 3, 4, 5, 6, 7)])
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_exception(self, shape, data_format, device, dtype):
        # Any non-5D input must be rejected with an AssertionError.
        aug_list = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1), data_format=data_format, same_on_frame=True)
        with pytest.raises(AssertionError):
            img = torch.randn(*shape, device=device, dtype=dtype)
            aug_list(img)

    @pytest.mark.parametrize(
        'augmentation',
        [
            K.RandomAffine(360, p=1.0),
            K.CenterCrop((3, 3), p=1.0),
            K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
            K.RandomCrop((5, 5), p=1.0),
            K.RandomErasing(p=1.0),
            K.RandomGrayscale(p=1.0),
            K.RandomHorizontalFlip(p=1.0),
            K.RandomVerticalFlip(p=1.0),
            K.RandomPerspective(p=1.0),
            K.RandomResizedCrop((5, 5), p=1.0),
            K.RandomRotation(360.0, p=1.0),
            K.RandomSolarize(p=1.0),
            K.RandomPosterize(p=1.0),
            K.RandomSharpness(p=1.0),
            K.RandomEqualize(p=1.0),
            K.RandomMotionBlur(3, 35.0, 0.5, p=1.0),
            K.Normalize(torch.tensor([0.5, 0.5, 0.5]), torch.tensor([0.5, 0.5, 0.5]), p=1.0),
            K.Denormalize(torch.tensor([0.5, 0.5, 0.5]), torch.tensor([0.5, 0.5, 0.5]), p=1.0),
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_augmentation(self, augmentation, data_format, device, dtype):
        # Every supported augmentation must be reproducible on a video batch.
        input = torch.randint(255, (1, 3, 3, 5, 6), device=device, dtype=dtype).repeat(2, 1, 1, 1, 1) / 255.0
        torch.manual_seed(21)
        aug_list = K.VideoSequential(augmentation, data_format=data_format, same_on_frame=True)
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0)
            ],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)
            ],
            [K.RandomAffine(360, p=1.0), kornia.color.BgrToRgb()],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0),
                K.RandomAffine(360, p=0.0)
            ],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0)],
            [K.RandomAffine(360, p=0.0)],
            [
                K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                K.RandomAffine(360, p=1.0),
                K.RandomMixUp(p=1.0)
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    @pytest.mark.parametrize('random_apply', [1, (1, 1), (1, ), 10, True, False])
    def test_same_on_frame(self, augmentations, data_format, random_apply, device, dtype):
        # With same_on_frame=True, frames repeated along the time axis must all
        # receive the identical transform.
        aug_list = K.VideoSequential(*augmentations,
                                     data_format=data_format,
                                     same_on_frame=True,
                                     random_apply=random_apply)
        if data_format == 'BCTHW':
            input = torch.randn(2, 3, 1, 5, 6, device=device, dtype=dtype).repeat(1, 1, 4, 1, 1)
            output = aug_list(input)
            if aug_list.return_label:
                output, _ = output
            assert (output[:, :, 0] == output[:, :, 1]).all()
            assert (output[:, :, 1] == output[:, :, 2]).all()
            assert (output[:, :, 2] == output[:, :, 3]).all()
        if data_format == 'BTCHW':
            input = torch.randn(2, 1, 3, 5, 6, device=device, dtype=dtype).repeat(1, 4, 1, 1, 1)
            output = aug_list(input)
            if aug_list.return_label:
                output, _ = output
            assert (output[:, 0] == output[:, 1]).all()
            assert (output[:, 1] == output[:, 2]).all()
            assert (output[:, 2] == output[:, 3]).all()
        reproducibility_test(input, aug_list)

    @pytest.mark.parametrize(
        'augmentations',
        [
            [K.RandomAffine(360, p=1.0)],
            [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)],
            [
                K.RandomAffine(360, p=0.0),
                K.ImageSequential(K.RandomAffine(360, p=0.0))
            ],
        ],
    )
    @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"])
    def test_against_sequential(self, augmentations, data_format, device, dtype):
        # With same_on_frame=False, VideoSequential must equal a plain
        # nn.Sequential applied to the frames flattened into the batch dim.
        aug_list_1 = K.VideoSequential(*augmentations, data_format=data_format, same_on_frame=False)
        aug_list_2 = torch.nn.Sequential(*augmentations)
        if data_format == 'BCTHW':
            input = torch.randn(2, 3, 1, 5, 6, device=device, dtype=dtype).repeat(1, 1, 4, 1, 1)
        if data_format == 'BTCHW':
            input = torch.randn(2, 1, 3, 5, 6, device=device, dtype=dtype).repeat(1, 4, 1, 1, 1)
        torch.manual_seed(0)
        output_1 = aug_list_1(input)
        torch.manual_seed(0)
        if data_format == 'BCTHW':
            input = input.transpose(1, 2)
        output_2 = aug_list_2(input.reshape(-1, 3, 5, 6))
        output_2 = output_2.view(2, 4, 3, 5, 6)
        if data_format == 'BCTHW':
            output_2 = output_2.transpose(1, 2)
        assert (output_1 == output_2).all(), dict(aug_list_1._params)

    @pytest.mark.jit
    @pytest.mark.skip(reason="turn off due to Union Type")
    def test_jit(self, device, dtype):
        # TorchScript parity check; currently skipped (Union types unscriptable).
        B, C, D, H, W = 2, 3, 5, 4, 4
        img = torch.ones(B, C, D, H, W, device=device, dtype=dtype)
        op = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1), same_on_frame=True)
        op_jit = torch.jit.script(op)
        assert_close(op(img), op_jit(img))
def get_augmenter(augmenter_type: str,
                  image_size: ImageSizeType,
                  dataset_mean: DatasetStatType,
                  dataset_std: DatasetStatType,
                  padding: PaddingInputType = 1. / 8.,
                  pad_if_needed: bool = False,
                  subset_size: int = 2) -> Union[Module, Callable]:
    """
    Args:
        augmenter_type: augmenter type (one of "simple", "fixed",
            "validation"/"test", "randaugment")
        image_size: (height, width) image size
        dataset_mean: dataset mean value in CHW
        dataset_std: dataset standard deviation in CHW
        padding: percent of image size to pad on each border of the image.
            If a sequence of length 4 is provided, it is used to pad left,
            top, right, bottom borders respectively. If a sequence of length
            2 is provided, it is used to pad left/right, top/bottom borders,
            respectively.
        pad_if_needed: bool flag for RandomCrop "pad_if_needed" option
        subset_size: number of augmentations used in subset

    Returns:
        nn.Module for Kornia augmentation or Callable for torchvision transform

    Raises:
        NotImplementedError: for an unknown augmenter_type.
    """
    # NOTE(review): uses `assert` for argument validation — stripped under
    # python -O; callers currently see AssertionError on bad padding.
    if not isinstance(padding, tuple):
        assert isinstance(padding, float)
        padding = (padding, padding, padding, padding)

    assert len(padding) == 2 or len(padding) == 4
    if len(padding) == 2:
        # padding of length 2 is used to pad left/right, top/bottom borders, respectively
        # padding of length 4 is used to pad left, top, right, bottom borders respectively
        padding = (padding[0], padding[1], padding[0], padding[1])

    # image_size is of shape (h,w); padding values is [left, top, right, bottom] borders
    padding = (int(image_size[1] * padding[0]), int(
        image_size[0] * padding[1]), int(image_size[1] * padding[2]),
               int(image_size[0] * padding[3]))

    augmenter_type = augmenter_type.strip().lower()
    if augmenter_type == "simple":
        # Crop + flip + normalize only.
        return nn.Sequential(
            K.RandomCrop(size=image_size,
                         padding=padding,
                         pad_if_needed=pad_if_needed,
                         padding_mode='reflect'),
            K.RandomHorizontalFlip(p=0.5),
            K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                        std=torch.tensor(dataset_std, dtype=torch.float32)),
        )
    elif augmenter_type == "fixed":
        # Fixed hand-tuned pipeline (blur, jitter, erasing, affine).
        return nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            # K.RandomVerticalFlip(p=0.2),
            K.RandomResizedCrop(size=image_size, scale=(0.8, 1.0), ratio=(1., 1.)),
            RandomAugmentation(p=0.5,
                               augmentation=F.GaussianBlur2d(
                                   kernel_size=(3, 3),
                                   sigma=(1.5, 1.5),
                                   border_type='constant')),
            K.ColorJitter(contrast=(0.75, 1.5)),
            # additive Gaussian noise
            K.RandomErasing(p=0.1),
            # Multiply
            K.RandomAffine(degrees=(-25., 25.),
                           translate=(0.2, 0.2),
                           scale=(0.8, 1.2),
                           shear=(-8., 8.)),
            K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                        std=torch.tensor(dataset_std, dtype=torch.float32)),
        )
    elif augmenter_type in ["validation", "test"]:
        # Evaluation: normalization only, no randomness.
        return nn.Sequential(
            K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                        std=torch.tensor(dataset_std, dtype=torch.float32)),
        )
    elif augmenter_type == "randaugment":
        return nn.Sequential(
            K.RandomCrop(size=image_size,
                         padding=padding,
                         pad_if_needed=pad_if_needed,
                         padding_mode='reflect'),
            K.RandomHorizontalFlip(p=0.5),
            RandAugmentNS(n=subset_size, m=10),
            K.Normalize(mean=torch.tensor(dataset_mean, dtype=torch.float32),
                        std=torch.tensor(dataset_std, dtype=torch.float32)),
        )
    else:
        raise NotImplementedError(
            f"\"{augmenter_type}\" is not a supported augmenter type")
def shearX(data, opt):
    """Apply a random horizontal shear in [-20, 20] degrees to `data`,
    returning a tensor of the original shape (`opt` is unused here)."""
    shear_aug = K.RandomAffine(degrees=0, shear=(-20, 20))
    # Collapse any leading dims into a single batch dim of (C, H, W) images.
    flat = data.view([-1] + list(data.shape[-3:]))
    return shear_aug(flat).view(data.shape)
def test_random_crops(self, device, dtype):
    """Random crop + fixed 360° affine: with hand-specified crop sources, the
    transformed boxes and keypoints must match hand-computed values, and the
    inverse must recover the originals."""
    torch.manual_seed(233)
    input = torch.randn(3, 3, 3, 3, device=device, dtype=dtype)
    bbox = torch.tensor([[[1.0, 1.0, 2.0, 2.0], [0.0, 0.0, 1.0, 2.0], [0.0, 0.0, 2.0, 1.0]]],
                        device=device, dtype=dtype).expand(3, -1, -1)
    points = torch.tensor([[[0.0, 0.0], [1.0, 1.0]]], device=device, dtype=dtype).expand(3, -1, -1)
    aug = K.AugmentationSequential(
        K.RandomCrop((3, 3), padding=1, cropping_mode='resample', fill=0),
        K.RandomAffine((360., 360.), p=1.),
        data_keys=["input", "mask", "bbox_xyxy", "keypoints"],
    )
    reproducibility_test((input, input, bbox, points), aug)

    _params = aug.forward_parameters(input.shape)
    # specifying the crops allows us to compute by hand the expected outputs
    _params[0].data['src'] = torch.tensor(
        [
            [[1.0, 2.0], [3.0, 2.0], [3.0, 4.0], [1.0, 4.0]],
            [[1.0, 1.0], [3.0, 1.0], [3.0, 3.0], [1.0, 3.0]],
            [[2.0, 0.0], [4.0, 0.0], [4.0, 2.0], [2.0, 2.0]],
        ],
        device=_params[0].data['src'].device,
        dtype=_params[0].data['src'].dtype,
    )
    expected_out_bbox = torch.tensor(
        [
            [[1.0, 0.0, 2.0, 1.0], [0.0, -1.0, 1.0, 1.0], [0.0, -1.0, 2.0, 0.0]],
            [[1.0, 1.0, 2.0, 2.0], [0.0, 0.0, 1.0, 2.0], [0.0, 0.0, 2.0, 1.0]],
            [[0.0, 2.0, 1.0, 3.0], [-1.0, 1.0, 0.0, 3.0], [-1.0, 1.0, 1.0, 2.0]],
        ],
        device=device,
        dtype=dtype,
    )
    expected_out_points = torch.tensor(
        [[[0.0, -1.0], [1.0, 0.0]], [[0.0, 0.0], [1.0, 1.0]], [[-1.0, 1.0], [0.0, 2.0]]],
        device=device, dtype=dtype)

    out = aug(input, input, bbox, points, params=_params)
    assert out[0].shape == (3, 3, 3, 3)
    # Image and mask go through identical geometry, so they must agree.
    assert_close(out[0], out[1], atol=1e-4, rtol=1e-4)
    assert out[2].shape == bbox.shape
    assert_close(out[2], expected_out_bbox, atol=1e-4, rtol=1e-4)
    assert out[3].shape == points.shape
    assert_close(out[3], expected_out_points, atol=1e-4, rtol=1e-4)

    # Inverse must restore the original boxes and keypoints.
    out_inv = aug.inverse(*out)
    assert out_inv[0].shape == input.shape
    assert_close(out_inv[0], out_inv[1], atol=1e-4, rtol=1e-4)
    assert out_inv[2].shape == bbox.shape
    assert_close(out_inv[2], bbox, atol=1e-4, rtol=1e-4)
    assert out_inv[3].shape == points.shape
    assert_close(out_inv[3], points, atol=1e-4, rtol=1e-4)
def __init__(
    self,
    image_size,
    latent_dim=512,
    style_depth=8,
    network_capacity=16,
    transparent=False,
    fp16=False,
    cl_reg=False,
    augment_fn=None,
    steps=1,
    lr=1e-4,
    fq_layers=None,
    fq_dict_size=256,
    attn_layers=None,
):
    """StyleGAN-style trainer wrapper: builds S/G/D plus EMA copies (SE/GE),
    their DiffGrad optimizers, optional AMP wrapping and an optional
    contrastive-learning regularizer for the discriminator.

    Fix: `fq_layers` and `attn_layers` previously defaulted to `[]` — a
    mutable default shared across instances. `None` sentinels are
    backward-compatible: omitting the arguments behaves exactly as before.
    """
    super().__init__()
    fq_layers = [] if fq_layers is None else fq_layers
    attn_layers = [] if attn_layers is None else attn_layers
    self.lr = lr
    self.steps = steps
    self.ema_updater = EMA(0.995)

    # Trainable networks.
    self.S = StyleVectorizer(latent_dim, style_depth)
    self.G = Generator(image_size,
                       latent_dim,
                       network_capacity,
                       transparent=transparent,
                       attn_layers=attn_layers)
    self.D = Discriminator(
        image_size,
        network_capacity,
        fq_layers=fq_layers,
        fq_dict_size=fq_dict_size,
        attn_layers=attn_layers,
        transparent=transparent,
    )

    # Exponential-moving-average copies; never trained directly.
    self.SE = StyleVectorizer(latent_dim, style_depth)
    self.GE = Generator(image_size,
                        latent_dim,
                        network_capacity,
                        transparent=transparent,
                        attn_layers=attn_layers)
    set_requires_grad(self.SE, False)
    set_requires_grad(self.GE, False)

    # Generator optimizer also updates the style mapper S.
    generator_params = list(self.G.parameters()) + list(
        self.S.parameters())
    self.G_opt = DiffGrad(generator_params, lr=self.lr, betas=(0.5, 0.9))
    self.D_opt = DiffGrad(self.D.parameters(), lr=self.lr, betas=(0.5, 0.9))

    self._init_weights()
    self.reset_parameter_averaging()
    self.cuda()

    if fp16:
        # apex AMP wraps all networks and both optimizers at O2.
        (self.S, self.G, self.D, self.SE,
         self.GE), (self.G_opt, self.D_opt) = amp.initialize(
             [self.S, self.G, self.D, self.SE, self.GE],
             [self.G_opt, self.D_opt],
             opt_level="O2")

    # experimental contrastive loss discriminator regularization
    if augment_fn is not None:
        self.augment_fn = augment_fn
    else:
        self.augment_fn = nn.Sequential(
            nn.ReflectionPad2d(int((sqrt(2) - 1) * image_size / 4)),
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.7),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(augs.RandomAffine(degrees=0,
                                          translate=(0.25, 0.25),
                                          shear=(15, 15)),
                        p=0.3),
            RandomApply(nn.Sequential(
                augs.RandomRotation(180),
                augs.CenterCrop(size=(image_size, image_size))),
                        p=0.2),
            augs.RandomResizedCrop(size=(image_size, image_size)),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            RandomApply(augs.RandomErasing(), p=0.1),
        )
    self.D_cl = (ContrastiveLearner(self.D,
                                    image_size,
                                    augment_fn=self.augment_fn,
                                    fp16=fp16,
                                    hidden_layer="flatten")
                 if cl_reg else None)
# SVI training loop for the transforming-template MNIST model:
# optimize the guide/encoder pair with ELBO, logging loss and reconstructions
# to TensorBoard every 100 steps.
opt = optim.Adam({}, clip_args={"clip_norm": 10.0})
elbo = infer.Trace_ELBO()
svi = infer.SVI(
    transforming_template_mnist, transforming_template_encoder, opt, loss=elbo
)
# Clear any previous TensorBoard run files.
files = glob.glob("runs/*")
for f in files:
    os.remove(f)
tb = SummaryWriter("runs/")
x_train = x_train.expand(516, 1, 28, 28)
# Pad 28x28 digits to 40x40 so rotations do not clip content.
x_train = F.pad(x_train, (6, 6, 6, 6))
for i in range(20000):
    x_rot = augmentation.RandomAffine(40.0)(
        x_train
    )  # randomly rotate the image by up to 40 degrees
    l = svi.step(x_rot, transforms, cond=True, grid_size=40, encoder=encoder)
    tb.add_scalar("loss", l, i)
    if (i % 100) == 0:
        print(i, l)
        x_rot = x_rot[:32]
        tb.add_image("originals", torchvision.utils.make_grid(x_rot), i)
        # Run the encoder (guide), then replay the model against its trace to
        # obtain reconstructions of the rotated digits.
        bwd_trace = poutine.trace(transforming_template_encoder).get_trace(
            x_rot, transforms, cond=True, grid_size=40, encoder=encoder
        )
        fwd_trace = poutine.trace(
            poutine.replay(transforming_template_mnist, trace=bwd_trace)
        ).get_trace(x_rot, transforms, cond=False, grid_size=40, encoder=encoder)
        recon = fwd_trace.nodes["pixels"]["fn"].mean
g_ema = Generator( args.size, args.latent_size, args.n_mlp, channel_multiplier=args.channel_multiplier, constant_input=args.constant_input, ).to(device) g_ema.requires_grad_(False) g_ema.eval() accumulate(g_ema, generator, 0) augment_fn = nn.Sequential( nn.ReflectionPad2d(int((math.sqrt(2) - 1) * args.size / 4)), # zoom out augs.RandomHorizontalFlip(), RandomApply(augs.RandomAffine(degrees=0, translate=(0.25, 0.25), shear=(15, 15)), p=0.2), RandomApply(augs.RandomRotation(180), p=0.2), augs.RandomResizedCrop(size=(args.size, args.size), scale=(1, 1), ratio=(1, 1)), RandomApply(augs.RandomResizedCrop(size=(args.size, args.size), scale=(0.5, 0.9)), p=0.1), # zoom in RandomApply(augs.RandomErasing(), p=0.1), ) contrast_learner = ( ContrastiveLearner(discriminator, args.size, augment_fn=augment_fn, hidden_layer=(-1, 0)) if args.contrastive > 0 else None ) g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1) d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) g_optim = th.optim.Adam(
def main(argv=None):
    """Train a DCGAN on an image folder, optionally with kornia batch
    augmentations applied inside the discriminator; logs to TensorBoard."""
    # CLI
    parser = argparse.ArgumentParser()
    parser.add_argument("name", help="Name of the experiment")
    parser.add_argument(
        "-a",
        "--augment",
        action="store_true",
        help="If True, we apply augmentations",
    )
    parser.add_argument("-b", "--batch-size", type=int, default=16, help="Batch size")
    parser.add_argument(
        "--b1",
        type=float,
        default=0.5,
        help="Adam optimizer hyperparamter",
    )
    parser.add_argument(
        "--b2",
        type=float,
        default=0.999,
        help="Adam optimizer hyperparamter",
    )
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to use",
    )
    parser.add_argument(
        "--eval-frequency",
        type=int,
        default=400,
        help="Generate generator images every `eval_frequency` epochs",
    )
    parser.add_argument(
        "--latent-dim",
        type=int,
        default=100,
        help="Dimensionality of the random noise",
    )
    parser.add_argument("--lr", type=float, default=0.0002, help="Learning rate")
    parser.add_argument(
        "--ndf",
        type=int,
        default=32,
        help="Number of discriminator feature maps (after first convolution)",
    )
    parser.add_argument(
        "--ngf",
        type=int,
        default=32,
        help=
        "Number of generator feature maps (before last transposed convolution)",
    )
    parser.add_argument(
        "-n",
        "--n-epochs",
        type=int,
        default=200,
        help="Number of training epochs",
    )
    parser.add_argument(
        "--mosaic-size",
        type=int,
        default=10,
        help="Size of the side of the rectangular mosaic",
    )
    parser.add_argument(
        "-p",
        "--prob",
        type=float,
        default=0.9,
        help="Probability of applying an augmentation",
    )
    args = parser.parse_args(argv)
    args_d = vars(args)
    print(args)
    img_size = 128
    # Additional parameters
    device = torch.device(args.device)
    mosaic_kwargs = {"nrow": args.mosaic_size, "normalize": True}
    n_mosaic_cells = args.mosaic_size * args.mosaic_size
    sample_showcase_ix = (
        0  # this one will be used to demonstrate the augmentations
    )
    # Batch-level differentiable augmentations (translation + erasing).
    augment_module = torch.nn.Sequential(
        K.RandomAffine(degrees=0, translate=(1 / 8, 1 / 8), p=args.prob),
        K.RandomErasing((0.0, 0.5), p=args.prob),
    )
    # Loss function
    adversarial_loss = torch.nn.BCELoss()
    # Initialize generator and discriminator
    generator = Generator(latent_dim=args.latent_dim, ngf=args.ngf)
    discriminator = Discriminator(
        ndf=args.ndf, augment_module=augment_module if args.augment else None)
    generator.to(device)
    discriminator.to(device)
    # Initialize weights
    generator.apply(init_weights_)
    discriminator.apply(init_weights_)
    # Configure data loader
    data_path = pathlib.Path("data")
    tform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    dataset = DatasetImages(
        data_path,
        transform=tform,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
    )
    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(),
                                   lr=args.lr,
                                   betas=(args.b1, args.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(),
                                   lr=args.lr,
                                   betas=(args.b1, args.b2))
    # Output path and metadata
    output_path = pathlib.Path("outputs") / args.name
    output_path.mkdir(exist_ok=True, parents=True)
    # Add other parameters (not included in CLI)
    args_d["time"] = datetime.now()
    args_d["kornia"] = str(augment_module)
    # Prepare tensorboard writer
    writer = SummaryWriter(output_path)
    # Log hyperparameters as text
    writer.add_text(
        "hyperparameter",
        pprint.pformat(args_d).replace(
            "\n", " \n"),  # markdown needs 2 spaces before newline
        0,
    )
    # Log true data
    writer.add_image(
        "true_data",
        make_grid(torch.stack([dataset[i] for i in range(n_mosaic_cells)]),
                  **mosaic_kwargs),
        0,
    )
    # Log augmented data
    batch_showcase = dataset[sample_showcase_ix][None, ...].repeat(
        n_mosaic_cells, 1, 1, 1)
    batch_showcase_aug = discriminator.augment_module(batch_showcase)
    writer.add_image("augmentations",
                     make_grid(batch_showcase_aug, **mosaic_kwargs), 0)
    # Prepare evaluation noise (fixed so mosaics are comparable across epochs)
    z_eval = torch.randn(n_mosaic_cells, args.latent_dim).to(device)
    for epoch in tqdm(range(args.n_epochs)):
        for i, imgs in enumerate(dataloader):
            n_samples, *_ = imgs.shape
            batches_done = epoch * len(dataloader) + i
            # Adversarial ground truths (0.9 = one-sided label smoothing)
            valid = 0.9 * torch.ones(
                n_samples, 1, device=device, dtype=torch.float32)
            fake = torch.zeros(n_samples, 1, device=device, dtype=torch.float32)
            # D preparation
            optimizer_D.zero_grad()
            # D loss on reals
            real_imgs = imgs.to(device)
            d_x = discriminator(real_imgs)
            real_loss = adversarial_loss(d_x, valid)
            real_loss.backward()
            # D loss on fakes
            z = torch.randn(n_samples, args.latent_dim).to(device)
            gen_imgs = generator(z)
            d_g_z1 = discriminator(gen_imgs.detach())
            fake_loss = adversarial_loss(d_g_z1, fake)
            fake_loss.backward()
            optimizer_D.step()  # we called backward twice, the result is a sum
            # G preparation
            optimizer_G.zero_grad()
            # G loss
            d_g_z2 = discriminator(gen_imgs)
            g_loss = adversarial_loss(d_g_z2, valid)
            g_loss.backward()
            optimizer_G.step()
            # Logging
            if batches_done % 50 == 0:
                writer.add_scalar("d_x", d_x.mean().item(), batches_done)
                writer.add_scalar("d_g_z1", d_g_z1.mean().item(), batches_done)
                writer.add_scalar("d_g_z2", d_g_z2.mean().item(), batches_done)
                writer.add_scalar("D_loss", (real_loss + fake_loss).item(),
                                  batches_done)
                writer.add_scalar("G_loss", g_loss.item(), batches_done)
            if epoch % args.eval_frequency == 0 and i == 0:
                generator.eval()
                discriminator.eval()
                # Generate fake images
                gen_imgs_eval = generator(z_eval)
                # Generate nice mosaic
                writer.add_image(
                    "fake",
                    make_grid(gen_imgs_eval.data, **mosaic_kwargs),
                    batches_done,
                )
                # Save checkpoint (and potentially overwrite an existing one)
                torch.save(generator, output_path / "model.pt")
                # Make sure generator and discriminator in the training mode
                generator.train()
                discriminator.train()