def test_batch_random_affine_3d(self, device, dtype):
    # TODO(jian): cuda and fp64
    if "cuda" in str(device) and dtype == torch.float64:
        pytest.skip("AssertionError: assert tensor(False, device='cuda:0')")

    f = RandomAffine3D((0, 0, 0), p=1.0, return_transform=True)  # No rotation
    tensor = torch.tensor(
        [[[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]], device=device, dtype=dtype
    )  # 1 x 1 x 1 x 3 x 3
    expected = torch.tensor(
        [[[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]], device=device, dtype=dtype
    )  # 1 x 1 x 1 x 3 x 3
    expected_transform = torch.tensor(
        [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]],
        device=device,
        dtype=dtype,
    )  # 1 x 4 x 4

    tensor = tensor.repeat(5, 3, 1, 1, 1)  # 5 x 3 x 3 x 3 x 3
    expected = expected.repeat(5, 3, 1, 1, 1)  # 5 x 3 x 3 x 3 x 3
    expected_transform = expected_transform.repeat(5, 1, 1)  # 5 x 4 x 4

    assert (f(tensor)[0] == expected).all()
    assert (f(tensor)[1] == expected_transform).all()

def test_param(
    self, degrees, translate, scale, shear, resample, align_corners, return_transform, same_on_batch, device, dtype
):
    _degrees = (
        degrees
        if isinstance(degrees, (int, float, list, tuple))
        else nn.Parameter(degrees.clone().to(device=device, dtype=dtype))
    )
    _translate = (
        translate
        if isinstance(translate, (int, float, list, tuple))
        else nn.Parameter(translate.clone().to(device=device, dtype=dtype))
    )
    _scale = (
        scale
        if isinstance(scale, (int, float, list, tuple))
        else nn.Parameter(scale.clone().to(device=device, dtype=dtype))
    )
    _shear = (
        shear
        if isinstance(shear, (int, float, list, tuple))
        else nn.Parameter(shear.clone().to(device=device, dtype=dtype))
    )

    torch.manual_seed(0)
    input = torch.randint(255, (2, 3, 10, 10, 10), device=device, dtype=dtype) / 255.0
    aug = RandomAffine3D(
        _degrees,
        _translate,
        _scale,
        _shear,
        resample,
        align_corners=align_corners,
        return_transform=return_transform,
        same_on_batch=same_on_batch,
        p=1.0,
    )

    if return_transform:
        output, _ = aug(input)
    else:
        output = aug(input)

    if len(list(aug.parameters())) != 0:
        mse = nn.MSELoss()
        opt = torch.optim.SGD(aug.parameters(), lr=10)
        # Use a target far from the output to ensure a big loss value is obtained,
        # so any gradient flow produces a visible parameter update.
        loss = mse(output, torch.ones_like(output) * 2)
        loss.backward()
        opt.step()

    # Run the same update check for each tensor-valued argument. Note that the
    # `shear` argument is stored on the module under the attribute `shears`.
    for arg, attr in ((degrees, 'degrees'), (translate, 'translate'), (scale, 'scale'), (shear, 'shears')):
        if isinstance(arg, (int, float, list, tuple)):
            continue
        param = getattr(aug, attr)
        assert isinstance(param, torch.Tensor)
        # Check whether the parameter was updated by the optimizer step.
        if resample == 'nearest' and param.is_cuda:
            # grid_sample in nearest mode on a cuda device returns nan rather than 0.
            pass
        elif resample == 'nearest' or torch.all(param.grad == 0.0):
            # grid_sample returns grad = 0 for nearest resampling, so no update happens:
            # https://discuss.pytorch.org/t/autograd-issue-with-f-grid-sample/76894
            assert (arg.to(device=device, dtype=dtype) - param.data).sum() == 0
        else:
            assert (arg.to(device=device, dtype=dtype) - param.data).sum() != 0
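
# A minimal sketch (not part of the original suite; the helper name is
# illustrative) of the differentiable-parameter pattern that test_param
# exercises: passing an nn.Parameter as the degree range lets a downstream
# loss backpropagate into the augmentation's sampling bounds, provided a
# differentiable resample mode such as 'bilinear' is used. With 'nearest',
# grid_sample yields zero gradients (or nan on cuda), per the
# discuss.pytorch.org thread linked above.
def _sketch_differentiable_degrees():
    degrees = nn.Parameter(torch.tensor([0.0, 30.0]))
    aug = RandomAffine3D(degrees, resample='bilinear', p=1.0)
    out = aug(torch.rand(2, 3, 10, 10, 10))
    out.mean().backward()
    return degrees.grad  # non-zero in bilinear mode; zero for 'nearest'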