def test_invalid_param_combinations(self, depth, height, width, degrees, translate, scale, shear, device, dtype):
    """Expect ``random_affine_generator3d`` to raise for an invalid argument combination.

    The argument values come from outside this method (presumably a pytest
    parametrization that is not visible here). ``translate``, ``scale`` and
    ``shear`` may be ``None``. ``degrees`` is always cast.
    """

    def to_target(tensor):
        # Optional ranges stay None; tensors move to the test device/dtype.
        if tensor is None:
            return None
        return tensor.to(device=device, dtype=dtype)

    # Argument preparation runs inside the ``raises`` context, so an error
    # raised while casting also satisfies the check.
    # NOTE(review): ``Exception`` is very broad; narrowing it requires knowing
    # which error type the generator raises for each combination.
    with pytest.raises(Exception):
        random_affine_generator3d(
            batch_size=8,
            depth=depth,
            height=height,
            width=width,
            degrees=degrees.to(device=device, dtype=dtype),
            translate=to_target(translate),
            scale=to_target(scale),
            shears=to_target(shear),
        )
def test_valid_param_combinations(
    self, batch_size, depth, height, width, degrees, translate, scale, shear, same_on_batch, device, dtype
):
    """Call ``random_affine_generator3d`` with a valid argument combination; it must not raise.

    The argument values come from outside this method (presumably a pytest
    parametrization that is not visible here). ``translate``, ``scale`` and
    ``shear`` may be ``None``. ``degrees`` is always cast.
    """

    def to_target(tensor):
        # Optional ranges stay None; tensors move to the test device/dtype.
        if tensor is None:
            return None
        return tensor.to(device=device, dtype=dtype)

    random_affine_generator3d(
        batch_size=batch_size,
        depth=depth,
        height=height,
        width=width,
        degrees=degrees.to(device=device, dtype=dtype),
        translate=to_target(translate),
        scale=to_target(scale),
        shears=to_target(shear),
        same_on_batch=same_on_batch,
    )
def test_same_on_batch(self, device, dtype):
    """Pin the generator output for seed 42 with ``same_on_batch=True``.

    A batch of 2 is drawn for a 200x200x200 volume. Every expected tensor
    below holds identical values across the batch dimension.
    """
    torch.manual_seed(42)
    opts = dict(device=device, dtype=dtype)
    # Sampling ranges. Shear has six rows, presumably one per shear
    # component (sxy..szy) -- confirm against the generator.
    degrees = torch.tensor([[10, 20], [10, 20], [10, 20]])
    translate = torch.tensor([0.1, 0.1, 0.1])
    scale = torch.tensor([[0.7, 1.2], [0.7, 1.2], [0.7, 1.2]])
    shear = torch.tensor([[0, 20], [0, 20], [0, 20], [0, 20], [0, 20], [0, 20]])
    # Every range is a concrete tensor here, so no ``None`` guards are needed.
    res = random_affine_generator3d(
        batch_size=2,
        depth=200,
        height=200,
        width=200,
        degrees=degrees.to(**opts),
        translate=translate.to(**opts),
        scale=scale.to(**opts),
        shears=shear.to(**opts),
        same_on_batch=True,
    )
    expected = dict(
        translations=torch.tensor(
            [[-9.7371, 11.7457, 17.6309], [-9.7371, 11.7457, 17.6309]], **opts),
        center=torch.tensor(
            [[99.5000, 99.5000, 99.5000], [99.5000, 99.5000, 99.5000]], **opts),
        scale=torch.tensor(
            [[1.1797, 0.8952, 1.0004], [1.1797, 0.8952, 1.0004]], **opts),
        angles=torch.tensor(
            [[18.8227, 19.1500, 13.8286], [18.8227, 19.1500, 13.8286]], **opts),
        sxy=torch.tensor([2.6637, 2.6637], **opts),
        sxz=torch.tensor([18.6920, 18.6920], **opts),
        syx=torch.tensor([11.8716, 11.8716], **opts),
        syz=torch.tensor([17.3881, 17.3881], **opts),
        szx=torch.tensor([11.3543, 11.3543], **opts),
        szy=torch.tensor([14.8219, 14.8219], **opts),
    )
    assert res.keys() == expected.keys()
    # One comparison per key, in the dict's insertion order, with shared tolerances.
    for key, want in expected.items():
        assert_close(res[key], want, rtol=1e-4, atol=1e-4)
def test_random_gen(self, device, dtype):
    """Pin the generator output for seed 42 with the default ``same_on_batch``.

    A batch of 2 is drawn for a 200x200x200 volume. The expected values
    differ between the two batch rows, except for ``center``.
    """
    torch.manual_seed(42)
    opts = dict(device=device, dtype=dtype)
    # Sampling ranges. Shear has six rows, presumably one per shear
    # component (sxy..szy) -- confirm against the generator.
    degrees = torch.tensor([[10, 20], [10, 20], [10, 20]])
    translate = torch.tensor([0.1, 0.1, 0.1])
    scale = torch.tensor([[0.7, 1.2], [0.7, 1.2], [0.7, 1.2]])
    shear = torch.tensor([[0, 20], [0, 20], [0, 20], [0, 20], [0, 20], [0, 20]])
    # Every range is a concrete tensor here, so no ``None`` guards are needed.
    res = random_affine_generator3d(
        batch_size=2,
        depth=200,
        height=200,
        width=200,
        degrees=degrees.to(**opts),
        translate=translate.to(**opts),
        scale=scale.to(**opts),
        shears=shear.to(**opts),
    )
    expected = dict(
        translations=torch.tensor(
            [[14.7762, 9.6438, 15.4177], [2.7086, -2.8238, 2.9562]], **opts),
        center=torch.tensor(
            [[99.5000, 99.5000, 99.5000], [99.5000, 99.5000, 99.5000]], **opts),
        scale=torch.tensor(
            [[0.8283, 1.1704, 1.1673], [1.0968, 0.7666, 0.9968]], **opts),
        angles=torch.tensor(
            [[18.8227, 13.8286, 13.9045], [19.1500, 19.5931, 16.0090]], **opts),
        sxy=torch.tensor([5.3316, 12.5490], **opts),
        sxz=torch.tensor([5.3926, 8.8273], **opts),
        syx=torch.tensor([5.9384, 16.6337], **opts),
        syz=torch.tensor([2.1063, 5.3899], **opts),
        szx=torch.tensor([7.1763, 3.9873], **opts),
        szy=torch.tensor([10.9438, 0.1232], **opts),
    )
    assert res.keys() == expected.keys()
    # One comparison per key, in the dict's insertion order, with shared tolerances.
    for key, want in expected.items():
        assert_close(res[key], want, rtol=1e-4, atol=1e-4)
def test_same_on_batch(self, device, dtype):
    """Pin the generator output for seed 42 with ``same_on_batch=True``.

    A batch of 2 is drawn for a 200x200x200 volume. Every expected tensor
    below holds identical values across the batch dimension.
    """
    torch.manual_seed(42)
    opts = dict(device=device, dtype=dtype)
    # Sampling ranges. Shear has six rows, presumably one per shear
    # component (sxy..szy) -- confirm against the generator.
    degrees = torch.tensor([[10, 20], [10, 20], [10, 20]])
    translate = torch.tensor([0.1, 0.1, 0.1])
    scale = torch.tensor([[0.7, 1.2], [0.7, 1.2], [0.7, 1.2]])
    shear = torch.tensor([[0, 20], [0, 20], [0, 20], [0, 20], [0, 20], [0, 20]])
    # Every range is a concrete tensor here, so no ``None`` guards are needed.
    res = random_affine_generator3d(
        batch_size=2,
        depth=200,
        height=200,
        width=200,
        degrees=degrees.to(**opts),
        translate=translate.to(**opts),
        scale=scale.to(**opts),
        shears=shear.to(**opts),
        same_on_batch=True,
    )
    expected = dict(
        translations=torch.tensor(
            [[18.2094, 17.1501, -16.6583], [18.2094, 17.1501, -16.6583]], **opts),
        center=torch.tensor(
            [[99.5000, 99.5000, 99.5000], [99.5000, 99.5000, 99.5000]], **opts),
        scale=torch.tensor(
            [[0.7263, 0.9631, 0.9384], [0.7263, 0.9631, 0.9384]], **opts),
        angles=torch.tensor(
            [[10.5815, 10.6291, 11.2359], [10.5815, 10.6291, 11.2359]], **opts),
        sxy=torch.tensor([2.6528, 2.6528], **opts),
        sxz=torch.tensor([3.1411, 3.1411], **opts),
        syx=torch.tensor([7.5073, 7.5073], **opts),
        syz=torch.tensor([16.8504, 16.8504], **opts),
        szx=torch.tensor([17.4100, 17.4100], **opts),
        szy=torch.tensor([7.5507, 7.5507], **opts),
    )
    assert res.keys() == expected.keys()
    # One comparison per key, in the dict's insertion order, with shared tolerances.
    for key, want in expected.items():
        assert_allclose(res[key], want, rtol=1e-4, atol=1e-4)
def test_random_gen(self, device, dtype):
    """Pin the generator output for seed 42 with the default ``same_on_batch``.

    A batch of 2 is drawn for a 200x200x200 volume. The expected values
    differ between the two batch rows, except for ``center``.
    """
    torch.manual_seed(42)
    opts = dict(device=device, dtype=dtype)
    # Sampling ranges. Shear has six rows, presumably one per shear
    # component (sxy..szy) -- confirm against the generator.
    degrees = torch.tensor([[10, 20], [10, 20], [10, 20]])
    translate = torch.tensor([0.1, 0.1, 0.1])
    scale = torch.tensor([[0.7, 1.2], [0.7, 1.2], [0.7, 1.2]])
    shear = torch.tensor([[0, 20], [0, 20], [0, 20], [0, 20], [0, 20], [0, 20]])
    # Every range is a concrete tensor here, so no ``None`` guards are needed.
    res = random_affine_generator3d(
        batch_size=2,
        depth=200,
        height=200,
        width=200,
        degrees=degrees.to(**opts),
        translate=translate.to(**opts),
        scale=scale.to(**opts),
        shears=shear.to(**opts),
    )
    expected = dict(
        translations=torch.tensor(
            [[13.7008, -4.8987, -16.4756], [14.8200, 4.4975, 8.0473]], **opts),
        center=torch.tensor(
            [[99.5000, 99.5000, 99.5000], [99.5000, 99.5000, 99.5000]], **opts),
        scale=torch.tensor(
            [[1.1776, 0.7418, 0.7785], [1.1644, 0.7663, 0.8877]], **opts),
        angles=torch.tensor(
            [[10.5815, 11.2359, 15.2617], [10.6291, 10.5258, 14.7678]], **opts),
        sxy=torch.tensor([12.4681, 8.7456], **opts),
        sxz=torch.tensor([1.4947, 13.6686], **opts),
        syx=torch.tensor([6.2448, 6.1812], **opts),
        syz=torch.tensor([0.6268, 0.8073], **opts),
        szx=torch.tensor([18.6382, 3.0425], **opts),
        szy=torch.tensor([5.3009, 2.6087], **opts),
    )
    assert res.keys() == expected.keys()
    # One comparison per key, in the dict's insertion order, with shared tolerances.
    for key, want in expected.items():
        assert_allclose(res[key], want, rtol=1e-4, atol=1e-4)