def test_random_rotation(self, device):
    """Seeded RandomRotation3D on a single 3 x 4 x 4 volume (this is included in doctest).

    The first instance is built with ``return_transform=True`` and is checked for both
    its output and its 4 x 4 transform; the second instance is checked on output only.
    """
    tol = {"rtol": 1e-6, "atol": 1e-4}

    torch.manual_seed(0)  # for random reproducibility
    rotate_with_mat = RandomRotation3D(degrees=45.0, return_transform=True)
    rotate = RandomRotation3D(degrees=45.0)

    plane = [[1., 0., 0., 2.], [0., 0., 0., 0.], [0., 1., 2., 0.], [0., 0., 1., 2.]]
    volume = torch.tensor([plane] * 3, device=device)  # 3 x 4 x 4

    expected = torch.tensor(
        [[[
            [[9.9412e-01, 0.0000e+00, 8.5407e-03, 1.9535e+00],
             [1.7328e-05, 3.8945e-03, 1.5617e-02, 1.0553e-04],
             [0.0000e+00, 9.8295e-01, 1.9789e+00, 4.9531e-02],
             [0.0000e+00, 0.0000e+00, 9.7034e-01, 1.9548e+00]],
            [[9.6646e-01, 0.0000e+00, 4.3866e-02, 1.9559e+00],
             [1.1586e-02, 0.0000e+00, 1.0260e-04, 0.0000e+00],
             [4.4472e-03, 9.9659e-01, 1.9833e+00, 3.4181e-05],
             [0.0000e+00, 7.8456e-03, 9.9959e-01, 1.9956e+00]],
            [[9.3772e-01, 0.0000e+00, 7.8179e-02, 1.8983e+00],
             [2.2707e-02, 0.0000e+00, 9.6477e-04, 2.2624e-02],
             [2.1575e-02, 9.9975e-01, 1.9243e+00, 0.0000e+00],
             [3.1300e-04, 3.2790e-02, 1.0249e+00, 1.9474e+00]],
        ]]],
        device=device,
    )
    expected_transform = torch.tensor(
        [[[0.7168, 0.5830, 0.3825, -1.1651],
          [-0.5853, 0.8012, -0.1242, 1.0699],
          [-0.3789, -0.1349, 0.9155, 0.7079],
          [0.0000, 0.0000, 0.0000, 1.0000]]],
        device=device,
    )
    expected_2 = torch.tensor(
        [[[
            [[0.5337, 0.0176, 0.0066, 0.2952],
             [0.1152, 0.1210, 0.7456, 1.4092],
             [0.0000, 0.0000, 0.3873, 0.5115],
             [0.0000, 0.0000, 0.0000, 0.0000]],
            [[0.1647, 1.5763, 0.5870, 0.0000],
             [0.0000, 0.2845, 0.0000, 0.0000],
             [0.0000, 0.5602, 0.6448, 0.2305],
             [0.3809, 1.3558, 1.5374, 1.5583]],
            [[0.0000, 0.0000, 0.0000, 0.0000],
             [0.4242, 0.0000, 0.0000, 0.0000],
             [1.4400, 0.1584, 0.0000, 0.0000],
             [0.0946, 0.0000, 0.0000, 0.0000]],
        ]]],
        device=device,
    )

    out, mat = rotate_with_mat(volume)
    assert_allclose(out, expected, **tol)
    assert_allclose(mat, expected_transform, **tol)
    assert_allclose(rotate(volume), expected_2, **tol)
def test_sequential(self, device, dtype):
    """Two chained RandomRotation3D ops inside nn.Sequential (seed 24).

    ``pipeline_all_mats`` returns a transform from both stages; ``pipeline_last_plain``
    builds its second stage without ``return_transform``. Only the composed
    transform is compared for the latter.
    """
    tol = {"rtol": 1e-6, "atol": 1e-4}

    torch.manual_seed(24)  # for random reproducibility
    pipeline_all_mats = nn.Sequential(
        RandomRotation3D(torch.tensor([-45.0, 90]), return_transform=True),
        RandomRotation3D(10.4, return_transform=True),
    )
    pipeline_last_plain = nn.Sequential(
        RandomRotation3D(torch.tensor([-45.0, 90]), return_transform=True),
        RandomRotation3D(10.4),
    )

    plane = [[1., 0., 0., 2.], [0., 0., 0., 0.], [0., 1., 2., 0.], [0., 0., 1., 2.]]
    volume = torch.tensor([plane] * 3, device=device, dtype=dtype)  # 3 x 4 x 4

    expected = torch.tensor(
        [[[
            [[0.3431, 0.1239, 0.0000, 1.0348],
             [0.0000, 0.2035, 0.1139, 0.1770],
             [0.0789, 0.9057, 1.7780, 0.0000],
             [0.0000, 0.2286, 1.2498, 1.2643]],
            [[0.5460, 0.2131, 0.0000, 1.1453],
             [0.0000, 0.0899, 0.0000, 0.4293],
             [0.0797, 1.0193, 1.6677, 0.0000],
             [0.0000, 0.2458, 1.2765, 1.0920]],
            [[0.6322, 0.2614, 0.0000, 0.9207],
             [0.0000, 0.0037, 0.0000, 0.6551],
             [0.0689, 0.9251, 1.3442, 0.0000],
             [0.0000, 0.2449, 0.9856, 0.6862]],
        ]]],
        device=device,
        dtype=dtype,
    )
    expected_transform = torch.tensor(
        [[[0.9857, -0.1686, -0.0019, 0.2762],
          [0.1668, 0.9739, 0.1538, -0.3650],
          [-0.0241, -0.1520, 0.9881, 0.2760],
          [0.0000, 0.0000, 0.0000, 1.0000]]],
        device=device,
        dtype=dtype,
    )
    expected_transform_2 = torch.tensor(
        [[[0.2348, -0.1615, 0.9585, 0.4316],
          [0.1719, 0.9775, 0.1226, -0.3467],
          [-0.9567, 0.1360, 0.2573, 1.9738],
          [0.0000, 0.0000, 0.0000, 1.0000]]],
        device=device,
        dtype=dtype,
    )

    # Call order matters: both pipelines draw from the same seeded RNG stream.
    out, mat = pipeline_all_mats(volume)
    _, mat_2 = pipeline_last_plain(volume)
    assert_allclose(out, expected, **tol)
    assert_allclose(mat, expected_transform, **tol)
    assert_allclose(mat_2, expected_transform_2, **tol)
def test_sequential(self, device):
    """Two chained RandomRotation3D ops inside nn.Sequential (seed 0, default dtype).

    ``pipeline_all_mats`` returns a transform from both stages; ``pipeline_last_plain``
    builds its second stage without ``return_transform``. Only the composed
    transform is compared for the latter.
    """
    tol = {"rtol": 1e-6, "atol": 1e-4}

    torch.manual_seed(0)  # for random reproducibility
    pipeline_all_mats = nn.Sequential(
        RandomRotation3D(torch.tensor([-45.0, 90]), return_transform=True),
        RandomRotation3D(10.4, return_transform=True),
    )
    pipeline_last_plain = nn.Sequential(
        RandomRotation3D(torch.tensor([-45.0, 90]), return_transform=True),
        RandomRotation3D(10.4),
    )

    plane = [[1., 0., 0., 2.], [0., 0., 0., 0.], [0., 1., 2., 0.], [0., 0., 1., 2.]]
    volume = torch.tensor([plane] * 3, device=device)  # 3 x 4 x 4

    expected = torch.tensor(
        [[[
            [[5.8413e-01, 3.1238e-01, 2.7060e-02, 4.1555e-01],
             [4.3787e-02, 8.6122e-02, 8.3644e-02, 6.6170e-01],
             [3.7727e-01, 8.5232e-01, 5.1960e-01, 9.1474e-02],
             [2.7198e-01, 6.9109e-01, 5.7987e-01, 1.7646e-01]],
            [[8.8109e-02, 1.6755e-01, 1.0494e-01, 3.5867e-02],
             [6.1397e-02, 2.6093e-01, 4.2951e-01, 3.5508e-01],
             [7.9497e-02, 5.5609e-01, 1.3745e+00, 6.2975e-01],
             [8.9196e-03, 3.6598e-01, 1.1553e+00, 1.2267e+00]],
            [[6.1910e-03, 1.4123e-02, 5.0222e-02, 6.4776e-03],
             [2.9561e-05, 1.1410e-01, 5.5537e-01, 2.2792e-01],
             [0.0000e+00, 5.5169e-02, 5.4580e-01, 4.4281e-01],
             [0.0000e+00, 2.9420e-03, 1.7743e-01, 3.0618e-01]],
        ]]],
        device=device,
    )
    expected_transform = torch.tensor(
        [[[0.4690, 0.4978, 0.7295, -1.3100],
          [-0.2616, 0.8673, -0.4236, 1.0961],
          [-0.8435, 0.0078, 0.5370, 1.5263],
          [0.0000, 0.0000, 0.0000, 1.0000]]],
        device=device,
    )
    expected_transform_2 = torch.tensor(
        [[[0.2207, 0.0051, 0.9753, -0.6914],
          [0.4092, 0.9072, -0.0974, -0.1240],
          [-0.8853, 0.4206, 0.1981, 1.4573],
          [0.0000, 0.0000, 0.0000, 1.0000]]],
        device=device,
    )

    # Call order matters: both pipelines draw from the same seeded RNG stream.
    out, mat = pipeline_all_mats(volume)
    _, mat_2 = pipeline_last_plain(volume)
    assert_allclose(out, expected, **tol)
    assert_allclose(mat, expected_transform, **tol)
    assert_allclose(mat_2, expected_transform_2, **tol)
def test_sequential(self, device):
    """Two chained RandomRotation3D ops inside nn.Sequential (seed 0, default dtype).

    ``pipeline_all_mats`` returns a transform from both stages; ``pipeline_last_plain``
    builds its second stage without ``return_transform``. Only the composed
    transform is compared for the latter.
    """
    tol = {"rtol": 1e-6, "atol": 1e-4}

    torch.manual_seed(0)  # for random reproducibility
    pipeline_all_mats = nn.Sequential(
        RandomRotation3D(torch.tensor([-45.0, 90]), return_transform=True),
        RandomRotation3D(10.4, return_transform=True),
    )
    pipeline_last_plain = nn.Sequential(
        RandomRotation3D(torch.tensor([-45.0, 90]), return_transform=True),
        RandomRotation3D(10.4),
    )

    plane = [[1., 0., 0., 2.], [0., 0., 0., 0.], [0., 1., 2., 0.], [0., 0., 1., 2.]]
    volume = torch.tensor([plane] * 3, device=device)  # 3 x 4 x 4

    expected = torch.tensor(
        [[[
            [[4.5604e-02, 4.6441e-03, 7.1205e-01, 7.3379e-01],
             [2.7580e-01, 3.3129e-02, 1.1986e-01, 5.7227e-01],
             [1.4329e-01, 4.0604e-02, 4.4003e-03, 1.3030e-01],
             [8.0267e-05, 9.0396e-03, 2.8991e-02, 4.0690e-03]],
            [[9.8822e-03, 4.4220e-02, 1.2963e-01, 3.9873e-02],
             [4.3757e-02, 3.4982e-01, 5.1378e-01, 8.6131e-02],
             [6.6809e-02, 5.5708e-01, 1.0904e+00, 4.6732e-01],
             [2.9877e-02, 1.9682e-01, 3.5764e-01, 1.0877e-01]],
            [[2.3905e-02, 2.1605e-01, 2.5145e-02, 3.3507e-04],
             [1.1453e-01, 1.1965e+00, 1.2008e+00, 3.6272e-01],
             [1.3368e-01, 5.4211e-01, 1.2059e+00, 1.0104e+00],
             [8.6762e-03, 6.7149e-02, 2.0946e-01, 2.7900e-01]],
        ]]],
        device=device,
    )
    expected_transform = torch.tensor(
        [[[0.8369, 0.0343, -0.5463, 0.7395],
          [-0.5104, 0.4091, -0.7563, 2.4083],
          [0.1976, 0.9118, 0.3599, -1.0240],
          [0.0000, 0.0000, 0.0000, 1.0000]]],
        device=device,
    )
    expected_transform_2 = torch.tensor(
        [[[0.9869, -0.1351, 0.0879, 0.1343],
          [0.1598, 0.7501, -0.6417, 0.7769],
          [0.0208, 0.6474, 0.7619, -0.7641],
          [0.0000, 0.0000, 0.0000, 1.0000]]],
        device=device,
    )

    # Call order matters: both pipelines draw from the same seeded RNG stream.
    out, mat = pipeline_all_mats(volume)
    _, mat_2 = pipeline_last_plain(volume)
    assert_allclose(out, expected, **tol)
    assert_allclose(mat, expected_transform, **tol)
    assert_allclose(mat_2, expected_transform_2, **tol)
def test_random_rotation(self, device, dtype):
    """Seeded RandomRotation3D on a 3 x 4 x 4 volume (this is included in doctest).

    The ``return_transform=True`` instance is checked for output and transform. The
    second instance's expected output equals the input volume (with two leading
    singleton axes).
    """
    tol = {"rtol": 1e-6, "atol": 1e-4}

    torch.manual_seed(0)  # for random reproducibility
    rotate_with_mat = RandomRotation3D(degrees=45.0, return_transform=True)
    rotate = RandomRotation3D(degrees=45.0)

    plane = [[1., 0., 0., 2.], [0., 0., 0., 0.], [0., 1., 2., 0.], [0., 0., 1., 2.]]
    volume = torch.tensor([plane] * 3, device=device, dtype=dtype)  # 3 x 4 x 4

    expected = torch.tensor(
        [[[
            [[0.2771, 0.0000, 0.0036, 0.0000],
             [0.5751, 0.0183, 0.7505, 0.4702],
             [0.0262, 0.2591, 0.5776, 0.4764],
             [0.0000, 0.0093, 0.0000, 0.0393]],
            [[0.0000, 0.0000, 0.0583, 0.0222],
             [0.1665, 0.0000, 1.0424, 1.0224],
             [0.1296, 0.4846, 1.4200, 1.2287],
             [0.0078, 0.3851, 0.3965, 0.3612]],
            [[0.0000, 0.7704, 0.6704, 0.0000],
             [0.0000, 0.0332, 0.2414, 0.0524],
             [0.0000, 0.3349, 1.4545, 1.3689],
             [0.0000, 0.0312, 0.5874, 0.8702]],
        ]]],
        device=device,
        dtype=dtype,
    )
    expected_transform = torch.tensor(
        [[[0.5784, 0.7149, -0.3929, -0.0471],
          [-0.3657, 0.6577, 0.6585, 0.4035],
          [0.7292, -0.2372, 0.6419, -0.3799],
          [0.0000, 0.0000, 0.0000, 1.0000]]],
        device=device,
        dtype=dtype,
    )
    # 1 x 1 x 3 x 4 x 4: the unmodified input volume.
    expected_2 = torch.tensor([[[plane] * 3]], device=device, dtype=dtype)

    out, mat = rotate_with_mat(volume)
    assert_allclose(out, expected, **tol)
    assert_allclose(mat, expected_transform, **tol)
    assert_allclose(rotate(volume), expected_2, **tol)
def test_random_rotation(self, device):
    """Seeded RandomRotation3D on a 3 x 4 x 4 volume, default dtype (this is included in doctest).

    The ``return_transform=True`` instance is checked for output and transform; the
    second instance, called afterwards, is checked on output only.
    """
    tol = {"rtol": 1e-6, "atol": 1e-4}

    torch.manual_seed(0)  # for random reproducibility
    rotate_with_mat = RandomRotation3D(degrees=45.0, return_transform=True)
    rotate = RandomRotation3D(degrees=45.0)

    plane = [[1., 0., 0., 2.], [0., 0., 0., 0.], [0., 1., 2., 0.], [0., 0., 1., 2.]]
    volume = torch.tensor([plane] * 3, device=device)  # 3 x 4 x 4

    expected = torch.tensor(
        [[[
            [[0.0000, 0.0000, 0.6810, 0.5250],
             [0.5052, 0.0000, 0.0000, 0.0613],
             [0.1159, 0.1072, 0.5324, 0.0870],
             [0.0000, 0.0000, 0.1927, 0.0000]],
            [[0.0000, 0.1683, 0.6963, 0.1131],
             [0.0566, 0.0000, 0.5215, 0.2796],
             [0.0694, 0.6039, 1.4519, 1.1240],
             [0.0000, 0.1325, 0.1542, 0.2510]],
            [[0.0000, 0.2054, 0.0000, 0.0000],
             [0.0026, 0.6088, 0.7358, 0.2319],
             [0.1262, 1.0830, 1.3687, 1.4940],
             [0.0000, 0.0416, 0.2012, 0.3124]],
        ]]],
        device=device,
    )
    expected_transform = torch.tensor(
        [[[0.6523, 0.3666, -0.6635, 0.6352],
          [-0.6185, 0.7634, -0.1862, 1.4689],
          [0.4382, 0.5318, 0.7247, -1.1797],
          [0.0000, 0.0000, 0.0000, 1.0000]]],
        device=device,
    )
    expected_2 = torch.tensor(
        [[[
            [[0.0000, 0.4771, 0.0243, 0.0000],
             [0.0000, 0.1652, 0.0000, 0.6771],
             [0.1668, 1.1430, 0.7131, 0.2692],
             [0.0285, 0.7100, 0.6012, 0.0000]],
            [[0.0000, 0.3068, 0.0000, 0.0000],
             [0.0000, 0.3175, 0.0000, 0.6602],
             [0.1330, 1.1962, 0.9750, 0.0000],
             [0.0648, 0.9818, 0.9785, 0.0000]],
            [[0.0000, 0.1136, 0.0000, 0.0000],
             [0.0518, 0.4617, 0.0000, 0.4928],
             [0.0407, 1.0954, 1.1413, 0.0000],
             [0.0587, 0.8768, 1.1815, 0.0000]],
        ]]],
        device=device,
    )

    out, mat = rotate_with_mat(volume)
    assert_allclose(out, expected, **tol)
    assert_allclose(mat, expected_transform, **tol)
    assert_allclose(rotate(volume), expected_2, **tol)
def test_sequential(self, device, dtype):
    """Two chained RandomRotation3D ops inside nn.Sequential (seed 0).

    ``pipeline_all_mats`` returns a transform from both stages; ``pipeline_last_plain``
    builds its second stage without ``return_transform``. Only the composed
    transform is compared for the latter.
    """
    tol = {"rtol": 1e-6, "atol": 1e-4}

    torch.manual_seed(0)  # for random reproducibility
    pipeline_all_mats = nn.Sequential(
        RandomRotation3D(torch.tensor([-45.0, 90]), return_transform=True),
        RandomRotation3D(10.4, return_transform=True),
    )
    pipeline_last_plain = nn.Sequential(
        RandomRotation3D(torch.tensor([-45.0, 90]), return_transform=True),
        RandomRotation3D(10.4),
    )

    plane = [[1., 0., 0., 2.], [0., 0., 0., 0.], [0., 1., 2., 0.], [0., 0., 1., 2.]]
    volume = torch.tensor([plane] * 3, device=device, dtype=dtype)  # 3 x 4 x 4

    expected = torch.tensor(
        [[[
            [[0.2752, 0.0000, 0.0000, 0.0000],
             [0.5767, 0.0059, 0.6440, 0.4307],
             [0.0000, 0.2793, 0.6638, 0.5716],
             [0.0000, 0.0049, 0.0000, 0.0685]],
            [[0.0000, 0.0000, 0.1806, 0.0000],
             [0.2138, 0.0000, 0.9061, 0.7966],
             [0.0657, 0.5395, 1.4299, 1.2912],
             [0.0000, 0.3600, 0.3088, 0.3655]],
            [[0.0000, 0.6515, 0.8861, 0.0000],
             [0.0000, 0.0000, 0.2278, 0.0000],
             [0.0027, 0.4403, 1.5462, 1.3480],
             [0.0000, 0.1182, 0.6297, 0.8623]],
        ]]],
        device=device,
        dtype=dtype,
    )
    expected_transform = torch.tensor(
        [[[0.6306, 0.6496, -0.4247, 0.0044],
          [-0.3843, 0.7367, 0.5563, 0.4151],
          [0.6743, -0.1876, 0.7142, -0.4443],
          [0.0000, 0.0000, 0.0000, 1.0000]]],
        device=device,
        dtype=dtype,
    )
    expected_transform_2 = torch.tensor(
        [[[0.9611, 0.0495, -0.2717, 0.2557],
          [0.1255, 0.7980, 0.5894, -0.4747],
          [0.2460, -0.6006, 0.7608, 0.7710],
          [0.0000, 0.0000, 0.0000, 1.0000]]],
        device=device,
        dtype=dtype,
    )

    # Call order matters: both pipelines draw from the same seeded RNG stream.
    out, mat = pipeline_all_mats(volume)
    _, mat_2 = pipeline_last_plain(volume)
    assert_allclose(out, expected, **tol)
    assert_allclose(mat, expected_transform, **tol)
    assert_allclose(mat_2, expected_transform_2, **tol)
def test_gradcheck(self, device):
    """Run autograd's gradcheck through RandomRotation3D(degrees=(15.0, 15.0), p=1.)."""
    torch.manual_seed(0)  # for random reproducibility
    # 3 x 3 x 3 random volume, converted by the project helper into a gradcheck-ready tensor.
    volume = utils.tensor_to_gradcheck_var(torch.rand((3, 3, 3)).to(device))
    rotation = RandomRotation3D(degrees=(15.0, 15.0), p=1.)
    assert gradcheck(rotation, (volume,), raise_exception=True)
def test_smoke(self):
    """Check the string representation of ``RandomRotation3D(degrees=45.5)``.

    Both sides are whitespace-normalized before comparison: the degrees tensor is
    embedded via torch's tensor printing, whose line breaks and column padding are
    a formatting detail rather than part of the repr contract under test.
    """
    f = RandomRotation3D(degrees=45.5)
    # Named `expected_repr` so the builtin `repr` is not shadowed.
    expected_repr = (
        "RandomRotation3D(degrees=tensor([[-45.5000, 45.5000], [-45.5000, 45.5000], "
        "[-45.5000, 45.5000]]), resample=BILINEAR, align_corners=False, p=0.5, "
        "p_batch=1.0, same_on_batch=False, return_transform=False)"
    )

    def _squash(text):
        # Collapse every run of whitespace (spaces, newlines, padding) to one space.
        return " ".join(text.split())

    assert _squash(str(f)) == _squash(expected_repr)
def test_batch_random_rotation(self, device, dtype):
    """Seeded RandomRotation3D (seed 24) over a batch of two identical volumes.

    The expected values show the first batch element returned unchanged with an
    identity transform, and the second one rotated.
    """
    tol = {"rtol": 1e-6, "atol": 1e-4}

    torch.manual_seed(24)  # for random reproducibility
    rotate = RandomRotation3D(degrees=45.0, return_transform=True)

    plane = [[1., 0., 0., 2.], [0., 0., 0., 0.], [0., 1., 2., 0.], [0., 0., 1., 2.]]
    volume = torch.tensor([[plane] * 3], device=device, dtype=dtype)  # 1 x 3 x 4 x 4

    unchanged = [[plane] * 3]  # 1 x 3 x 4 x 4, identical to the input
    rotated = [[
        [[0.0000, 0.0726, 0.0000, 0.0000],
         [0.1038, 1.0134, 0.5566, 0.1519],
         [0.0000, 1.0849, 1.1068, 0.0000],
         [0.1242, 1.1065, 0.9681, 0.0000]],
        [[0.0000, 0.0047, 0.0166, 0.0000],
         [0.0579, 0.4459, 0.0000, 0.4728],
         [0.1864, 1.3349, 0.7530, 0.3251],
         [0.1431, 1.2481, 0.4471, 0.0000]],
        [[0.0000, 0.4840, 0.2314, 0.0000],
         [0.0000, 0.0328, 0.0000, 0.1434],
         [0.1899, 0.5580, 0.0000, 0.9170],
         [0.0000, 0.2042, 0.1571, 0.0855]],
    ]]
    # 2 x 1 x 3 x 4 x 4
    expected = torch.tensor([unchanged, rotated], device=device, dtype=dtype)

    identity = [[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]]
    # 2 x 4 x 4
    expected_transform = torch.tensor(
        [identity,
         [[0.7522, -0.6326, -0.1841, 1.5047],
          [0.6029, 0.5482, 0.5796, -0.8063],
          [-0.2657, -0.5470, 0.7938, 1.4252],
          [0.0000, 0.0000, 0.0000, 1.0000]]],
        device=device,
        dtype=dtype,
    )

    # repeat() with more sizes than dims prepends an axis: (1, 3, 4, 4) -> (2, 1, 3, 4, 4).
    volume = volume.repeat(2, 1, 1, 1, 1)
    out, mat = rotate(volume)
    assert_allclose(out, expected, **tol)
    assert_allclose(mat, expected_transform, **tol)
def test_batch_random_rotation(self, device, dtype):
    """Seeded RandomRotation3D (seed 0) over a batch of two identical volumes.

    The expected values show the first batch element rotated and the second one
    returned unchanged with an identity transform.
    """
    tol = {"rtol": 1e-6, "atol": 1e-4}

    torch.manual_seed(0)  # for random reproducibility
    rotate = RandomRotation3D(degrees=45.0, return_transform=True)

    plane = [[1., 0., 0., 2.], [0., 0., 0., 0.], [0., 1., 2., 0.], [0., 0., 1., 2.]]
    volume = torch.tensor([[plane] * 3], device=device, dtype=dtype)  # 1 x 3 x 4 x 4

    rotated = [[
        [[0.0000, 0.5106, 0.1146, 0.0000],
         [0.0000, 0.1261, 0.0000, 0.4723],
         [0.1714, 0.9931, 0.5442, 0.4684],
         [0.0193, 0.5802, 0.4195, 0.0000]],
        [[0.0000, 0.2386, 0.0000, 0.0000],
         [0.0187, 0.3527, 0.0000, 0.6119],
         [0.1294, 1.2251, 0.9130, 0.0942],
         [0.0962, 1.0769, 0.8448, 0.0000]],
        [[0.0000, 0.0202, 0.0000, 0.0000],
         [0.1092, 0.5845, 0.1038, 0.4598],
         [0.0000, 1.1218, 1.0796, 0.0000],
         [0.0780, 0.9513, 1.1278, 0.0000]],
    ]]
    unchanged = [[plane] * 3]  # 1 x 3 x 4 x 4, identical to the input
    # 2 x 1 x 3 x 4 x 4
    expected = torch.tensor([rotated, unchanged], device=device, dtype=dtype)

    identity = [[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]]
    # 2 x 4 x 4
    expected_transform = torch.tensor(
        [[[0.7894, -0.6122, 0.0449, 1.1892],
          [0.5923, 0.7405, -0.3176, -0.1816],
          [0.1612, 0.2773, 0.9472, -0.6049],
          [0.0000, 0.0000, 0.0000, 1.0000]],
         identity],
        device=device,
        dtype=dtype,
    )

    # repeat() with more sizes than dims prepends an axis: (1, 3, 4, 4) -> (2, 1, 3, 4, 4).
    volume = volume.repeat(2, 1, 1, 1, 1)
    out, mat = rotate(volume)
    assert_allclose(out, expected, **tol)
    assert_allclose(mat, expected_transform, **tol)
def test_batch_random_rotation(self, device):
    """Seeded RandomRotation3D (seed 0, default dtype) over a batch of two identical volumes.

    The expected values show both batch elements rotated, with a distinct 4 x 4
    transform for each.
    """
    tol = {"rtol": 1e-6, "atol": 1e-4}

    torch.manual_seed(0)  # for random reproducibility
    rotate = RandomRotation3D(degrees=45.0, return_transform=True)

    plane = [[1., 0., 0., 2.], [0., 0., 0., 0.], [0., 1., 2., 0.], [0., 0., 1., 2.]]
    volume = torch.tensor([[plane] * 3], device=device)  # 1 x 3 x 4 x 4

    # 2 x 1 x 3 x 4 x 4
    expected = torch.tensor(
        [
            [[
                [[9.9412e-01, 0.0000e+00, 8.5407e-03, 1.9535e+00],
                 [1.7328e-05, 3.8945e-03, 1.5617e-02, 1.0553e-04],
                 [0.0000e+00, 9.8295e-01, 1.9789e+00, 4.9531e-02],
                 [0.0000e+00, 0.0000e+00, 9.7034e-01, 1.9548e+00]],
                [[9.6646e-01, 0.0000e+00, 4.3866e-02, 1.9559e+00],
                 [1.1586e-02, 0.0000e+00, 1.0260e-04, 0.0000e+00],
                 [4.4472e-03, 9.9659e-01, 1.9833e+00, 3.4181e-05],
                 [0.0000e+00, 7.8456e-03, 9.9959e-01, 1.9956e+00]],
                [[9.3772e-01, 0.0000e+00, 7.8179e-02, 1.8983e+00],
                 [2.2707e-02, 0.0000e+00, 9.6477e-04, 2.2624e-02],
                 [2.1575e-02, 9.9975e-01, 1.9243e+00, 0.0000e+00],
                 [3.1300e-04, 3.2790e-02, 1.0249e+00, 1.9474e+00]],
            ]],
            [[
                [[5.9268e-01, 4.6201e-01, 0.0000e+00, 1.3414e-01],
                 [5.0854e-02, 0.0000e+00, 1.2416e-02, 1.1548e+00],
                 [6.1057e-01, 9.5876e-01, 1.9510e-01, 0.0000e+00],
                 [7.4132e-01, 8.6556e-01, 1.9031e-01, 0.0000e+00]],
                [[0.0000e+00, 1.0143e-01, 2.2231e-01, 0.0000e+00],
                 [0.0000e+00, 1.8495e-01, 5.0593e-01, 4.9479e-01],
                 [4.9456e-02, 5.0849e-01, 1.5325e+00, 7.8474e-01],
                 [0.0000e+00, 4.7385e-01, 1.3469e+00, 1.1854e+00]],
                [[0.0000e+00, 0.0000e+00, 0.0000e+00, 9.7841e-02],
                 [0.0000e+00, 0.0000e+00, 1.2882e-01, 8.0175e-01],
                 [0.0000e+00, 0.0000e+00, 0.0000e+00, 7.2928e-01],
                 [0.0000e+00, 0.0000e+00, 0.0000e+00, 2.1100e-01]],
            ]],
        ],
        device=device,
    )
    # 2 x 4 x 4
    expected_transform = torch.tensor(
        [[[0.7559, 0.2793, -0.5921, 0.7132],
          [-0.2756, 0.9561, 0.0991, 0.1928],
          [0.5938, 0.0883, 0.7998, -0.4259],
          [0.0000, 0.0000, 0.0000, 1.0000]],
         [[0.8194, -0.3079, -0.4836, 1.3678],
          [0.0754, 0.8941, -0.4415, 0.7456],
          [0.5683, 0.3253, 0.7558, -0.6899],
          [0.0000, 0.0000, 0.0000, 1.0000]]],
        device=device,
    )

    # repeat() with more sizes than dims prepends an axis: (1, 3, 4, 4) -> (2, 1, 3, 4, 4).
    volume = volume.repeat(2, 1, 1, 1, 1)
    out, mat = rotate(volume)
    assert_allclose(out, expected, **tol)
    assert_allclose(mat, expected_transform, **tol)
def test_batch_random_rotation(self, device):
    """Seeded RandomRotation3D (seed 0, default dtype) over a batch of two identical volumes.

    The expected values show the first batch element rotated and the second one
    returned unchanged with an identity transform.
    """
    tol = {"rtol": 1e-6, "atol": 1e-4}

    torch.manual_seed(0)  # for random reproducibility
    rotate = RandomRotation3D(degrees=45.0, return_transform=True)

    plane = [[1., 0., 0., 2.], [0., 0., 0., 0.], [0., 1., 2., 0.], [0., 0., 1., 2.]]
    volume = torch.tensor([[plane] * 3], device=device)  # 1 x 3 x 4 x 4

    rotated = [[
        [[7.5651e-01, 6.4340e-02, 0.0000e+00, 0.0000e+00],
         [8.9954e-02, 2.0099e-01, 8.2089e-01, 4.4695e-01],
         [0.0000e+00, 4.8303e-01, 8.0751e-01, 1.1574e+00],
         [0.0000e+00, 6.5891e-02, 1.7392e-01, 2.9013e-01]],
        [[4.0104e-01, 0.0000e+00, 1.9018e-02, 7.2668e-04],
         [3.5247e-01, 0.0000e+00, 6.5445e-01, 3.7179e-01],
         [5.4804e-02, 7.0015e-01, 1.7578e+00, 1.2048e+00],
         [3.7536e-02, 3.8235e-01, 9.0737e-01, 1.1033e+00]],
        [[1.2648e-02, 0.0000e+00, 9.4951e-01, 4.6696e-01],
         [1.2791e-01, 0.0000e+00, 8.2977e-02, 0.0000e+00],
         [1.3047e-02, 2.8671e-01, 8.3294e-01, 1.7991e-01],
         [0.0000e+00, 5.8076e-02, 6.7866e-01, 1.5130e+00]],
    ]]
    unchanged = [[plane] * 3]  # 1 x 3 x 4 x 4, identical to the input
    # 2 x 1 x 3 x 4 x 4
    expected = torch.tensor([rotated, unchanged], device=device)

    identity = [[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]]
    # 2 x 4 x 4
    expected_transform = torch.tensor(
        [[[0.8017, 0.4358, -0.4090, 0.0527],
          [-0.0877, 0.7627, 0.6408, -0.1533],
          [0.5912, -0.4779, 0.6497, 0.1803],
          [0.0000, 0.0000, 0.0000, 1.0000]],
         identity],
        device=device,
    )

    # repeat() with more sizes than dims prepends an axis: (1, 3, 4, 4) -> (2, 1, 3, 4, 4).
    volume = volume.repeat(2, 1, 1, 1, 1)
    out, mat = rotate(volume)
    assert_allclose(out, expected, **tol)
    assert_allclose(mat, expected_transform, **tol)
def test_random_rotation(self, device, dtype):
    """Seeded RandomRotation3D on a 3 x 4 x 4 volume (this is included in doctest).

    After checking the ``return_transform=True`` instance, the RNG is reseeded and the
    plain instance is expected to produce the very same output.
    """
    tol = {"rtol": 1e-6, "atol": 1e-4}

    torch.manual_seed(0)  # for random reproducibility
    rotate_with_mat = RandomRotation3D(degrees=45.0, return_transform=True)
    rotate = RandomRotation3D(degrees=45.0)

    plane = [[1., 0., 0., 2.], [0., 0., 0., 0.], [0., 1., 2., 0.], [0., 0., 1., 2.]]
    volume = torch.tensor([plane] * 3, device=device, dtype=dtype)  # 3 x 4 x 4

    expected = torch.tensor(
        [[[
            [[0.0000, 0.0000, 0.6810, 0.5250],
             [0.5052, 0.0000, 0.0000, 0.0613],
             [0.1159, 0.1072, 0.5324, 0.0870],
             [0.0000, 0.0000, 0.1927, 0.0000]],
            [[0.0000, 0.1683, 0.6963, 0.1131],
             [0.0566, 0.0000, 0.5215, 0.2796],
             [0.0694, 0.6039, 1.4519, 1.1240],
             [0.0000, 0.1325, 0.1542, 0.2510]],
            [[0.0000, 0.2054, 0.0000, 0.0000],
             [0.0026, 0.6088, 0.7358, 0.2319],
             [0.1261, 1.0830, 1.3687, 1.4940],
             [0.0000, 0.0416, 0.2012, 0.3124]],
        ]]],
        device=device,
        dtype=dtype,
    )
    expected_transform = torch.tensor(
        [[[0.6523, 0.3666, -0.6635, 0.6352],
          [-0.6185, 0.7634, -0.1862, 1.4689],
          [0.4382, 0.5318, 0.7247, -1.1797],
          [0.0000, 0.0000, 0.0000, 1.0000]]],
        device=device,
        dtype=dtype,
    )

    out, mat = rotate_with_mat(volume)
    assert_allclose(out, expected, **tol)
    assert_allclose(mat, expected_transform, **tol)

    # Same seed again, so the plain instance should sample the same rotation.
    torch.manual_seed(0)  # for random reproducibility
    assert_allclose(rotate(volume), expected, **tol)
def test_param(self, degrees, resample, align_corners, return_transform, same_on_batch, device, dtype):
    """Check whether learnable ``degrees`` are (or are not) updated by one SGD step.

    If ``degrees`` is not a plain int/float/list/tuple, it is wrapped in an
    ``nn.Parameter``. One SGD step is then taken on an MSE loss, and
    ``aug._param_generator.degrees`` is compared element-wise against the
    original values.
    """
    is_plain = isinstance(degrees, (int, float, list, tuple))
    _degrees = degrees if is_plain else nn.Parameter(degrees.clone().to(device=device, dtype=dtype))

    torch.manual_seed(0)
    volume = torch.randint(255, (2, 3, 10, 10, 10), device=device, dtype=dtype) / 255.0
    aug = RandomRotation3D(
        _degrees,
        resample,
        align_corners=align_corners,
        return_transform=return_transform,
        same_on_batch=same_on_batch,
        p=1.0,
    )

    if return_transform:
        output, _ = aug(volume)
    else:
        output = aug(volume)

    params = list(aug.parameters())
    if params:
        mse = nn.MSELoss()
        opt = torch.optim.SGD(params, lr=10)
        loss = mse(output, torch.ones_like(output) * 2)  # to ensure that a big loss value could be obtained
        loss.backward()
        opt.step()

    if is_plain:
        return

    learned = aug._param_generator.degrees
    assert isinstance(learned, torch.Tensor)
    # Compare element-wise: summing the differences first would let opposite-signed
    # changes cancel, hiding a real update or failing when the values did move.
    delta = degrees.to(device=device, dtype=dtype) - learned.data
    # Assert if param not updated
    if resample == 'nearest' and learned.is_cuda:
        # grid_sample in nearest mode and cuda device returns nan than 0
        pass
    elif resample == 'nearest' or torch.all(learned._grad == 0.0):
        # grid_sample will return grad = 0 for resample nearest
        # https://discuss.pytorch.org/t/autograd-issue-with-f-grid-sample/76894
        assert (delta == 0).all()
    else:
        assert (delta != 0).any()
def test_same_on_batch(self, device, dtype):
    """With ``same_on_batch=True`` both batch elements must come out identical."""
    aug = RandomRotation3D(degrees=40, same_on_batch=True)
    # eye(6) lifted to 1 x 1 x 6 x 6, then tiled to 2 x 3 x 6 x 6 x 6.
    volume = torch.eye(6, device=device, dtype=dtype)[None, None].repeat(2, 3, 6, 1, 1)
    rotated = aug(volume)
    assert (rotated[0] == rotated[1]).all()
def smoke_test(self, device):
    """Legacy repr check for ``RandomRotation3D(degrees=45.5)``.

    NOTE(review): the name does not start with ``test_``, so pytest's default
    ``python_functions = "test_*"`` pattern will not collect this method. Its expected
    string also has a different format from the one asserted in ``test_smoke`` in this
    file (no tensor-valued degrees, resample, p, ...), so it would likely fail if it
    were collected. Confirm before renaming. ``device`` is unused.
    """
    f = RandomRotation3D(degrees=45.5)
    # Named `expected_repr` so the builtin `repr` is not shadowed.
    expected_repr = "RandomRotation3D(degrees=45.5, return_transform=False)"
    assert str(f) == expected_repr
def test_sequential(self, device, dtype):
    """Chain two RandomRotation3D ops in AugmentationSequential (seed 24).

    Checks the output volume and the pipeline's ``transform_matrix`` attribute.
    """
    tol = {"rtol": 1e-6, "atol": 1e-4}

    torch.manual_seed(24)  # for random reproducibility
    pipeline = AugmentationSequential(
        RandomRotation3D(torch.tensor([-45.0, 90])),
        RandomRotation3D(10.4),
    )

    plane = [[1.0, 0.0, 0.0, 2.0], [0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 2.0, 0.0], [0.0, 0.0, 1.0, 2.0]]
    volume = torch.tensor([plane] * 3, device=device, dtype=dtype)  # 3 x 4 x 4

    expected = torch.tensor(
        [[[
            [[0.3431, 0.1239, 0.0000, 1.0348],
             [0.0000, 0.2035, 0.1139, 0.1770],
             [0.0789, 0.9057, 1.7780, 0.0000],
             [0.0000, 0.2286, 1.2498, 1.2643]],
            [[0.5460, 0.2131, 0.0000, 1.1453],
             [0.0000, 0.0899, 0.0000, 0.4293],
             [0.0797, 1.0193, 1.6677, 0.0000],
             [0.0000, 0.2458, 1.2765, 1.0920]],
            [[0.6322, 0.2614, 0.0000, 0.9207],
             [0.0000, 0.0037, 0.0000, 0.6551],
             [0.0689, 0.9251, 1.3442, 0.0000],
             [0.0000, 0.2449, 0.9856, 0.6862]],
        ]]],
        device=device,
        dtype=dtype,
    )
    expected_transform = torch.tensor(
        [[[0.9857, -0.1686, -0.0019, 0.2762],
          [0.1668, 0.9739, 0.1538, -0.3650],
          [-0.0241, -0.1520, 0.9881, 0.2760],
          [0.0000, 0.0000, 0.0000, 1.0000]]],
        device=device,
        dtype=dtype,
    )

    out = pipeline(volume)
    assert_close(out, expected, **tol)
    assert_close(pipeline.transform_matrix, expected_transform, **tol)