class TestVideoSequential: @pytest.mark.parametrize('shape', [(3, 4), (2, 3, 4), (2, 3, 5, 6), (2, 3, 4, 5, 6, 7)]) @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"]) def test_exception(self, shape, data_format, device, dtype): aug_list = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1), data_format=data_format, same_on_frame=True) with pytest.raises(AssertionError): img = torch.randn(*shape, device=device, dtype=dtype) aug_list(img) @pytest.mark.parametrize( 'augmentation', [ K.RandomAffine(360, p=1.0), K.CenterCrop((3, 3), p=1.0), K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0), K.RandomCrop((5, 5), p=1.0), K.RandomErasing(p=1.0), K.RandomGrayscale(p=1.0), K.RandomHorizontalFlip(p=1.0), K.RandomVerticalFlip(p=1.0), K.RandomPerspective(p=1.0), K.RandomResizedCrop((5, 5), p=1.0), K.RandomRotation(360.0, p=1.0), K.RandomSolarize(p=1.0), K.RandomPosterize(p=1.0), K.RandomSharpness(p=1.0), K.RandomEqualize(p=1.0), K.RandomMotionBlur(3, 35.0, 0.5, p=1.0), K.Normalize(torch.tensor([0.5, 0.5, 0.5]), torch.tensor([0.5, 0.5, 0.5]), p=1.0), K.Denormalize(torch.tensor([0.5, 0.5, 0.5]), torch.tensor([0.5, 0.5, 0.5]), p=1.0), ], ) @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"]) def test_augmentation(self, augmentation, data_format, device, dtype): input = torch.randint(255, (1, 3, 3, 5, 6), device=device, dtype=dtype).repeat(2, 1, 1, 1, 1) / 255.0 torch.manual_seed(21) aug_list = K.VideoSequential(augmentation, data_format=data_format, same_on_frame=True) reproducibility_test(input, aug_list) @pytest.mark.parametrize( 'augmentations', [ [ K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0), K.RandomAffine(360, p=1.0) ], [ K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0), K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0) ], [K.RandomAffine(360, p=1.0), kornia.color.BgrToRgb()], [ K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0), K.RandomAffine(360, p=0.0) ], [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=0.0)], [K.RandomAffine(360, p=0.0)], [ K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0), K.RandomAffine(360, p=1.0), K.RandomMixUp(p=1.0) ], ], ) @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"]) @pytest.mark.parametrize('random_apply', [1, (1, 1), (1, ), 10, True, False]) def test_same_on_frame(self, augmentations, data_format, random_apply, device, dtype): aug_list = K.VideoSequential(*augmentations, data_format=data_format, same_on_frame=True, random_apply=random_apply) if data_format == 'BCTHW': input = torch.randn(2, 3, 1, 5, 6, device=device, dtype=dtype).repeat(1, 1, 4, 1, 1) output = aug_list(input) if aug_list.return_label: output, _ = output assert (output[:, :, 0] == output[:, :, 1]).all() assert (output[:, :, 1] == output[:, :, 2]).all() assert (output[:, :, 2] == output[:, :, 3]).all() if data_format == 'BTCHW': input = torch.randn(2, 1, 3, 5, 6, device=device, dtype=dtype).repeat(1, 4, 1, 1, 1) output = aug_list(input) if aug_list.return_label: output, _ = output assert (output[:, 0] == output[:, 1]).all() assert (output[:, 1] == output[:, 2]).all() assert (output[:, 2] == output[:, 3]).all() reproducibility_test(input, aug_list) @pytest.mark.parametrize( 'augmentations', [ [K.RandomAffine(360, p=1.0)], [K.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0)], [ K.RandomAffine(360, p=0.0), K.ImageSequential(K.RandomAffine(360, p=0.0)) ], ], ) @pytest.mark.parametrize('data_format', ["BCTHW", "BTCHW"]) def test_against_sequential(self, augmentations, data_format, device, dtype): aug_list_1 = K.VideoSequential(*augmentations, data_format=data_format, same_on_frame=False) aug_list_2 = torch.nn.Sequential(*augmentations) if data_format == 'BCTHW': input = torch.randn(2, 3, 1, 5, 6, device=device, dtype=dtype).repeat(1, 1, 4, 1, 1) if data_format == 'BTCHW': input = torch.randn(2, 1, 3, 5, 6, device=device, dtype=dtype).repeat(1, 4, 1, 1, 1) torch.manual_seed(0) output_1 = aug_list_1(input) torch.manual_seed(0) if data_format == 'BCTHW': input = input.transpose(1, 2) output_2 = aug_list_2(input.reshape(-1, 3, 5, 6)) output_2 = output_2.view(2, 4, 3, 5, 6) if data_format == 'BCTHW': output_2 = output_2.transpose(1, 2) assert (output_1 == output_2).all(), dict(aug_list_1._params) @pytest.mark.jit @pytest.mark.skip(reason="turn off due to Union Type") def test_jit(self, device, dtype): B, C, D, H, W = 2, 3, 5, 4, 4 img = torch.ones(B, C, D, H, W, device=device, dtype=dtype) op = K.VideoSequential(K.ColorJitter(0.1, 0.1, 0.1, 0.1), same_on_frame=True) op_jit = torch.jit.script(op) assert_close(op(img), op_jit(img))
def __init__( self, net, image_size, hidden_layer=-2, project_hidden=True, project_dim=128, augment_both=True, use_nt_xent_loss=False, augment_fn=None, use_bilinear=False, use_momentum=False, momentum_value=0.999, key_encoder=None, temperature=0.1, batch_size=128, ): super().__init__() self.net = OutputHiddenLayer(net, layer=hidden_layer) DEFAULT_AUG = nn.Sequential( # RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8), # augs.RandomGrayscale(p=0.2), augs.RandomHorizontalFlip(), augs.RandomVerticalFlip(), augs.RandomSolarize(), augs.RandomPosterize(), augs.RandomSharpness(), augs.RandomEqualize(), augs.RandomRotation(degrees=8.0), RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1), augs.RandomResizedCrop((image_size, image_size), p=0.1), ) self.b = batch_size self.h = image_size self.w = image_size self.augment = default(augment_fn, DEFAULT_AUG) self.augment_both = augment_both self.temperature = temperature self.use_nt_xent_loss = use_nt_xent_loss self.project_hidden = project_hidden self.projection = None self.project_dim = project_dim self.use_bilinear = use_bilinear self.bilinear_w = None self.use_momentum = use_momentum self.ema_updater = EMA(momentum_value) self.key_encoder = key_encoder # for accumulating queries and keys across calls self.queries = None self.keys = None random_data = ( ( torch.randn(1, 3, image_size, image_size), torch.randn(1, 3, image_size, image_size), torch.randn(1, 3, image_size, image_size), ), torch.tensor([1]), ) # send a mock image tensor to instantiate parameters self.forward(random_data)