def test_param(self, angle, direction, border_type, resample, return_transform, same_on_batch, device, dtype):
    """Check that tensor ``angle``/``direction`` arguments become trainable parameters.

    A tensor argument is wrapped in ``nn.Parameter`` before being handed to
    ``RandomMotionBlur``. After one SGD step on an MSE loss, the tensor stored on
    ``aug._param_generator`` must have moved if it received a non-zero gradient,
    and must be unchanged otherwise. Plain Python numbers/sequences are passed
    through as-is and are not checked.
    """
    plain_types = (float, int, list, tuple)

    def _to_param(value):
        # Clone so the fixture tensor itself is never modified by the optimiser.
        return nn.Parameter(value.clone().to(device=device, dtype=dtype))

    def _grad_is_zero(param):
        # A parameter outside the autograd graph keeps ``grad is None``; SGD skips it,
        # so it counts as "not updated" (torch.all(None == 0.0) would raise instead).
        return param.grad is None or bool(torch.all(param.grad == 0.0))

    def _diff(initial, param):
        return initial.to(device=device, dtype=dtype) - param.data

    _angle = angle if isinstance(angle, plain_types) else _to_param(angle)
    # Same accepted types as ``angle`` so a scalar direction is not sent to ``.clone()``.
    _direction = direction if isinstance(direction, plain_types) else _to_param(direction)

    torch.manual_seed(0)
    img = torch.randint(255, (2, 3, 10, 10), device=device, dtype=dtype) / 255.0
    aug = RandomMotionBlur((3, 3), _angle, _direction, border_type, resample, return_transform, same_on_batch, p=1.0)

    output = aug(img)
    if return_transform:
        output, _ = output

    if len(list(aug.parameters())) != 0:
        mse = nn.MSELoss()
        opt = torch.optim.SGD(aug.parameters(), lr=0.1)
        loss = mse(output, torch.ones_like(output) * 2)
        loss.backward()
        opt.step()

    if not isinstance(angle, plain_types):
        angle_param = aug._param_generator.angle
        assert isinstance(angle_param, torch.Tensor)
        if resample == 'nearest' and angle_param.is_cuda:
            # grid_sample in nearest mode on a CUDA device returns nan rather than 0,
            # so the update cannot be checked reliably here.
            pass
        elif resample == 'nearest' or _grad_is_zero(angle_param):
            # grid_sample will return grad = 0 for resample nearest
            # https://discuss.pytorch.org/t/autograd-issue-with-f-grid-sample/76894
            # Element-wise check: a summed difference could hide cancelling updates.
            assert torch.all(_diff(angle, angle_param) == 0)
        else:
            # Assert the param was updated by the optimiser step.
            assert torch.any(_diff(angle, angle_param) != 0)

    if not isinstance(direction, plain_types):
        direction_param = aug._param_generator.direction
        assert isinstance(direction_param, torch.Tensor)
        if _grad_is_zero(direction_param):
            # grid_sample will return grad = 0 for resample nearest
            # https://discuss.pytorch.org/t/autograd-issue-with-f-grid-sample/76894
            assert torch.all(_diff(direction, direction_param) == 0)
        else:
            # Assert the param was updated by the optimiser step.
            assert torch.any(_diff(direction, direction_param) != 0)
def test_smoke(self, device):
    """``str()`` of a default-probability RandomMotionBlur matches a known string.

    The expected text shows the ``(10, 30)`` angle range rendered as a tensor and
    the scalar ``direction=0.5`` rendered as the range ``[-0.5, 0.5]``.
    """
    blur = RandomMotionBlur(kernel_size=(3, 5), angle=(10, 30), direction=0.5)
    expected_repr = (
        "RandomMotionBlur(kernel_size=(3, 5), angle=tensor([10., 30.]), direction=tensor([-0.5000, 0.5000]), "
        "border_type='constant', p=0.5, p_batch=1.0, same_on_batch=False, return_transform=False)"
    )
    assert str(blur) == expected_repr
def test_random_motion_blur(self, same_on_batch, return_transform, p, device, dtype):
    """Check output shape, the returned transform, and batch (in)dependence.

    One random image is repeated across the batch. With ``same_on_batch`` the two
    outputs must match; with ``p == 0`` the output must equal the input; otherwise
    the two batch entries must differ. When ``return_transform`` is set, the
    returned matrix must equal ``kornia.eye_like(3, ...)`` (identity).
    """
    f = RandomMotionBlur(
        kernel_size=(3, 5),
        angle=(10, 30),
        direction=0.5,
        same_on_batch=same_on_batch,
        return_transform=return_transform,
        p=p,
    )
    torch.manual_seed(0)
    batch_size = 2
    # Build the input on the parametrised device/dtype so those fixtures are
    # actually exercised (they were previously accepted but unused).
    img = torch.randn(1, 3, 5, 6, device=device, dtype=dtype).repeat(batch_size, 1, 1, 1)

    output = f(img)

    if return_transform:
        assert len(
            output
        ) == 2, f"must return a length 2 tuple if return_transform is True. Got {len(output)}."
        identity = kornia.eye_like(3, img)
        output, mat = output
        assert_allclose(mat, identity, rtol=1e-4, atol=1e-4)

    if same_on_batch:
        assert_allclose(output[0], output[1], rtol=1e-4, atol=1e-4)
    elif p == 0:
        assert_allclose(output, img, rtol=1e-4, atol=1e-4)
    else:
        assert not torch.allclose(output[0], output[1], rtol=1e-4, atol=1e-4)
    assert output.shape == torch.Size([batch_size, 3, 5, 6])
def test_gradcheck(self, device):
    """Run ``gradcheck`` through RandomMotionBlur with explicitly supplied parameters."""
    torch.manual_seed(0)  # for random reproducibility
    img = utils.tensor_to_gradcheck_var(torch.rand((1, 3, 11, 7)).to(device))  # to var

    # TODO: Gradcheck for param random gen failed. Suspect get_motion_kernel2d issue.
    # Hence the parameters are passed in explicitly instead of being sampled.
    # NOTE(review): ksize_factor=31 differs from kernel_size=3 below — confirm intended.
    fixed_params = dict(
        batch_prob=torch.tensor([True]),
        ksize_factor=torch.tensor([31]),
        angle_factor=torch.tensor([30.]),
        direction_factor=torch.tensor([-0.5]),
        border_type=torch.tensor([0]),
    )
    blur = RandomMotionBlur(kernel_size=3, angle=(10, 30), direction=(-0.5, 0.5), p=1.0)
    assert gradcheck(blur, (img, fixed_params), raise_exception=True)
def test_against_functional(self, input_shape):
    """The module output must equal ``motion_blur`` called with the parameters it sampled."""
    x = torch.randn(*input_shape)
    blur = RandomMotionBlur(kernel_size=(3, 5), angle=(10, 30), direction=0.5, p=1.0)
    actual = blur(x)

    # Parameters recorded by the module during the forward call above.
    sampled = blur._params
    expected = motion_blur(
        x,
        sampled['ksize_factor'].unique().item(),
        sampled['angle_factor'],
        sampled['direction_factor'],
        blur.border_type.name.lower(),
    )
    assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)