Example #1
0
    def test_batch_random_affine_3d(self, device, dtype):
        """Batched ``RandomAffine3D`` with zero rotation degrees leaves inputs unchanged.

        A 1 x 1 x 1 x 3 x 3 volume is tiled to a 5 x 3 x 1 x 3 x 3 batch and passed
        through ``RandomAffine3D((0, 0, 0), p=1.0, return_transform=True)``. The
        warped output must equal the (tiled) input values and the returned
        transform must be the 4 x 4 identity for every batch element.
        """
        # TODO(jian): cuda and fp64
        if "cuda" in str(device) and dtype == torch.float64:
            pytest.skip(
                "AssertionError: assert tensor(False, device='cuda:0')")

        f = RandomAffine3D((0, 0, 0), p=1.0,
                           return_transform=True)  # No rotation
        tensor = torch.tensor(
            [[[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]],
            device=device,
            dtype=dtype)  # 1 x 1 x 1 x 3 x 3

        expected = torch.tensor(
            [[[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]]],
            device=device,
            dtype=dtype)  # 1 x 1 x 1 x 3 x 3

        expected_transform = torch.eye(4, device=device, dtype=dtype)[None]  # 1 x 4 x 4

        tensor = tensor.repeat(5, 3, 1, 1, 1)  # 5 x 3 x 1 x 3 x 3
        expected = expected.repeat(5, 3, 1, 1, 1)  # 5 x 3 x 1 x 3 x 3
        expected_transform = expected_transform.repeat(5, 1, 1)  # 5 x 4 x 4

        # Call the augmentation once so the checked output and transform come from
        # the same invocation (two calls would sample parameters independently).
        output, transform = f(tensor)
        assert (output == expected).all()
        assert (transform == expected_transform).all()
Example #2
0
    def test_param(
        self, degrees, translate, scale, shear, resample, align_corners, return_transform, same_on_batch, device, dtype
    ):
        """Check that tensor-valued ``RandomAffine3D`` ranges behave as trainable parameters.

        Each of ``degrees``, ``translate``, ``scale`` and ``shear`` is either a plain
        Python value (int, float, list or tuple), passed through unchanged, or a
        tensor, which is cloned, moved to ``device``/``dtype`` and wrapped in
        ``nn.Parameter``. If the augmentation exposes any parameters, one SGD step
        on an MSE loss is taken. Then every tensor-valued range is checked:

        * ``resample == 'nearest'`` on CUDA: skipped, because ``grid_sample`` in
          nearest mode on CUDA returns NaN rather than 0 gradients;
        * ``resample == 'nearest'`` elsewhere, or no / all-zero gradient: the value
          must be unchanged;
        * otherwise: the value must have changed.
        """
        plain_types = (int, float, list, tuple)

        def as_param(value):
            # Plain Python ranges pass through; tensors become trainable Parameters.
            if isinstance(value, plain_types):
                return value
            return nn.Parameter(value.clone().to(device=device, dtype=dtype))

        _degrees, _translate, _scale, _shear = map(as_param, (degrees, translate, scale, shear))

        torch.manual_seed(0)
        # randint(255) yields 0..254, so the inputs lie in [0, 1).
        inputs = torch.randint(255, (2, 3, 10, 10, 10), device=device, dtype=dtype) / 255.0
        aug = RandomAffine3D(
            _degrees,
            _translate,
            _scale,
            _shear,
            resample,
            align_corners=align_corners,
            return_transform=return_transform,
            same_on_batch=same_on_batch,
            p=1.0,
        )

        if return_transform:
            output, _ = aug(inputs)
        else:
            output = aug(inputs)

        params = list(aug.parameters())
        if params:
            mse = nn.MSELoss()
            opt = torch.optim.SGD(params, lr=10)
            loss = mse(output, torch.ones_like(output) * 2)  # to ensure that a big loss value could be obtained
            loss.backward()
            opt.step()

        # (original argument, attribute on the augmentation); `shear` is stored as `shears`.
        checks = (
            (degrees, 'degrees'),
            (translate, 'translate'),
            (scale, 'scale'),
            (shear, 'shears'),
        )
        for original, attr in checks:
            if isinstance(original, plain_types):
                continue
            param = getattr(aug, attr)
            assert isinstance(param, torch.Tensor)
            if resample == 'nearest' and param.is_cuda:
                # grid_sample in nearest mode and cuda device returns nan than 0
                continue
            # Absolute differences: a signed sum could cancel to 0 even when
            # individual entries were updated.
            change = (original.to(device=device, dtype=dtype) - param.data).abs().sum()
            grad = param.grad
            # A parameter without a gradient is skipped by SGD, so it must be unchanged.
            if resample == 'nearest' or grad is None or torch.all(grad == 0.0):
                # grid_sample will return grad = 0 for resample nearest
                # https://discuss.pytorch.org/t/autograd-issue-with-f-grid-sample/76894
                assert change == 0
            else:
                assert change != 0