def test_invalid_param_combinations(self, depth, height, width, degrees, translate, scale, shear, device, dtype):
     """Expect the 3D affine generator to reject the given parameter combination.

     ``degrees`` is always moved to ``device`` and cast to ``dtype``; the
     optional ``translate``, ``scale`` and ``shear`` tensors get the same
     treatment unless they are ``None``. The generator is called with a fixed
     batch size of 8 and the test passes only if an exception is raised.
     """

     def _cast(tensor):
         # Optional ranges are forwarded as ``None`` without conversion.
         return None if tensor is None else tensor.to(device=device, dtype=dtype)

     with pytest.raises(Exception):
         # The conversions are kept inside the ``raises`` block so that a
         # failure while converting is also accepted as the expected error.
         params = dict(
             degrees=degrees.to(device=device, dtype=dtype),
             translate=_cast(translate),
             scale=_cast(scale),
             shears=_cast(shear),
         )
         random_affine_generator3d(
             batch_size=8, depth=depth, height=height, width=width, **params
         )
 def test_valid_param_combinations(
     self, batch_size, depth, height, width, degrees, translate, scale, shear, same_on_batch, device, dtype
 ):
     """Run the 3D affine generator on a parameter combination that must be accepted.

     ``degrees`` is always moved to ``device`` and cast to ``dtype``; the
     optional ``translate``, ``scale`` and ``shear`` tensors are converted the
     same way unless they are ``None``. Nothing is asserted about the result:
     the test passes when the call completes without raising.
     """

     def _cast(tensor):
         # Optional ranges are forwarded as ``None`` without conversion.
         return None if tensor is None else tensor.to(device=device, dtype=dtype)

     params = dict(
         degrees=degrees.to(device=device, dtype=dtype),
         translate=_cast(translate),
         scale=_cast(scale),
         shears=_cast(shear),
     )
     random_affine_generator3d(
         batch_size=batch_size,
         depth=depth,
         height=height,
         width=width,
         same_on_batch=same_on_batch,
         **params,
     )
 def test_same_on_batch(self, device, dtype):
     """Compare ``same_on_batch=True`` generator output with stored references.

     The RNG is seeded with 42, the generator is called for a batch of two
     200x200x200 volumes, and every returned entry is compared against a
     reference tensor (``rtol`` and ``atol`` of 1e-4). The reference rows are
     identical for both batch entries.

     NOTE(review): a method with this same name is defined again further down
     in this file, using different reference values and ``assert_allclose``.
     If both sit in the same class, the later definition replaces this one --
     confirm which of the two is meant to run.
     """
     tolerance = dict(rtol=1e-4, atol=1e-4)

     def _param(values):
         # Create with torch's default dtype first and convert afterwards, so
         # the inputs see the same rounding as a create-then-``.to()`` chain.
         return torch.tensor(values).to(device=device, dtype=dtype)

     def _ref(values):
         return torch.tensor(values, device=device, dtype=dtype)

     torch.manual_seed(42)
     res = random_affine_generator3d(
         batch_size=2,
         depth=200,
         height=200,
         width=200,
         degrees=_param([[10, 20]] * 3),
         translate=_param([0.1] * 3),
         scale=_param([[0.7, 1.2]] * 3),
         shears=_param([[0, 20]] * 6),
         same_on_batch=True,
     )

     # Insertion order below is also the order in which entries are checked.
     expected = {
         'translations': _ref([[-9.7371, 11.7457, 17.6309]] * 2),
         'center': _ref([[99.5000, 99.5000, 99.5000]] * 2),
         'scale': _ref([[1.1797, 0.8952, 1.0004]] * 2),
         'angles': _ref([[18.8227, 19.1500, 13.8286]] * 2),
         'sxy': _ref([2.6637] * 2),
         'sxz': _ref([18.6920] * 2),
         'syx': _ref([11.8716] * 2),
         'syz': _ref([17.3881] * 2),
         'szx': _ref([11.3543] * 2),
         'szy': _ref([14.8219] * 2),
     }

     assert res.keys() == expected.keys()
     for name, reference in expected.items():
         assert_close(res[name], reference, **tolerance)
 def test_random_gen(self, device, dtype):
     """Compare per-sample generator output with stored references.

     The RNG is seeded with 42, the generator is called for a batch of two
     200x200x200 volumes without ``same_on_batch``, and every returned entry
     is compared against a reference tensor (``rtol`` and ``atol`` of 1e-4).

     NOTE(review): a method with this same name is defined again further down
     in this file, using different reference values and ``assert_allclose``.
     If both sit in the same class, the later definition replaces this one --
     confirm which of the two is meant to run.
     """
     tolerance = dict(rtol=1e-4, atol=1e-4)

     def _param(values):
         # Create with torch's default dtype first and convert afterwards, so
         # the inputs see the same rounding as a create-then-``.to()`` chain.
         return torch.tensor(values).to(device=device, dtype=dtype)

     def _ref(values):
         return torch.tensor(values, device=device, dtype=dtype)

     torch.manual_seed(42)
     res = random_affine_generator3d(
         batch_size=2,
         depth=200,
         height=200,
         width=200,
         degrees=_param([[10, 20]] * 3),
         translate=_param([0.1] * 3),
         scale=_param([[0.7, 1.2]] * 3),
         shears=_param([[0, 20]] * 6),
     )

     # Insertion order below is also the order in which entries are checked.
     expected = {
         'translations': _ref([[14.7762, 9.6438, 15.4177], [2.7086, -2.8238, 2.9562]]),
         'center': _ref([[99.5000, 99.5000, 99.5000]] * 2),
         'scale': _ref([[0.8283, 1.1704, 1.1673], [1.0968, 0.7666, 0.9968]]),
         'angles': _ref([[18.8227, 13.8286, 13.9045], [19.1500, 19.5931, 16.0090]]),
         'sxy': _ref([5.3316, 12.5490]),
         'sxz': _ref([5.3926, 8.8273]),
         'syx': _ref([5.9384, 16.6337]),
         'syz': _ref([2.1063, 5.3899]),
         'szx': _ref([7.1763, 3.9873]),
         'szy': _ref([10.9438, 0.1232]),
     }

     assert res.keys() == expected.keys()
     for name, reference in expected.items():
         assert_close(res[name], reference, **tolerance)
 def test_same_on_batch(self, device, dtype):
     """Compare ``same_on_batch=True`` generator output with stored references.

     The RNG is seeded with 42, the generator is called for a batch of two
     200x200x200 volumes, and every returned entry is compared against a
     reference tensor via ``assert_allclose`` (``rtol`` and ``atol`` of 1e-4).
     The reference rows are identical for both batch entries.

     NOTE(review): a method with this same name is already defined earlier in
     this file, using different reference values and ``assert_close``. If both
     sit in the same class, this definition replaces the earlier one --
     confirm which of the two is meant to run.
     """
     tolerance = dict(rtol=1e-4, atol=1e-4)

     def _param(values):
         # Create with torch's default dtype first and convert afterwards, so
         # the inputs see the same rounding as a create-then-``.to()`` chain.
         return torch.tensor(values).to(device=device, dtype=dtype)

     def _ref(values):
         return torch.tensor(values, device=device, dtype=dtype)

     torch.manual_seed(42)
     res = random_affine_generator3d(
         batch_size=2,
         depth=200,
         height=200,
         width=200,
         degrees=_param([[10, 20]] * 3),
         translate=_param([0.1] * 3),
         scale=_param([[0.7, 1.2]] * 3),
         shears=_param([[0, 20]] * 6),
         same_on_batch=True,
     )

     # Insertion order below is also the order in which entries are checked.
     expected = {
         'translations': _ref([[18.2094, 17.1501, -16.6583]] * 2),
         'center': _ref([[99.5000, 99.5000, 99.5000]] * 2),
         'scale': _ref([[0.7263, 0.9631, 0.9384]] * 2),
         'angles': _ref([[10.5815, 10.6291, 11.2359]] * 2),
         'sxy': _ref([2.6528] * 2),
         'sxz': _ref([3.1411] * 2),
         'syx': _ref([7.5073] * 2),
         'syz': _ref([16.8504] * 2),
         'szx': _ref([17.4100] * 2),
         'szy': _ref([7.5507] * 2),
     }

     assert res.keys() == expected.keys()
     for name, reference in expected.items():
         assert_allclose(res[name], reference, **tolerance)
 def test_random_gen(self, device, dtype):
     """Compare per-sample generator output with stored references.

     The RNG is seeded with 42, the generator is called for a batch of two
     200x200x200 volumes without ``same_on_batch``, and every returned entry
     is compared against a reference tensor via ``assert_allclose`` (``rtol``
     and ``atol`` of 1e-4).

     NOTE(review): a method with this same name is already defined earlier in
     this file, using different reference values and ``assert_close``. If both
     sit in the same class, this definition replaces the earlier one --
     confirm which of the two is meant to run.
     """
     tolerance = dict(rtol=1e-4, atol=1e-4)

     def _param(values):
         # Create with torch's default dtype first and convert afterwards, so
         # the inputs see the same rounding as a create-then-``.to()`` chain.
         return torch.tensor(values).to(device=device, dtype=dtype)

     def _ref(values):
         return torch.tensor(values, device=device, dtype=dtype)

     torch.manual_seed(42)
     res = random_affine_generator3d(
         batch_size=2,
         depth=200,
         height=200,
         width=200,
         degrees=_param([[10, 20]] * 3),
         translate=_param([0.1] * 3),
         scale=_param([[0.7, 1.2]] * 3),
         shears=_param([[0, 20]] * 6),
     )

     # Insertion order below is also the order in which entries are checked.
     expected = {
         'translations': _ref([[13.7008, -4.8987, -16.4756], [14.8200, 4.4975, 8.0473]]),
         'center': _ref([[99.5000, 99.5000, 99.5000]] * 2),
         'scale': _ref([[1.1776, 0.7418, 0.7785], [1.1644, 0.7663, 0.8877]]),
         'angles': _ref([[10.5815, 11.2359, 15.2617], [10.6291, 10.5258, 14.7678]]),
         'sxy': _ref([12.4681, 8.7456]),
         'sxz': _ref([1.4947, 13.6686]),
         'syx': _ref([6.2448, 6.1812]),
         'syz': _ref([0.6268, 0.8073]),
         'szx': _ref([18.6382, 3.0425]),
         'szy': _ref([5.3009, 2.6087]),
     }

     assert res.keys() == expected.keys()
     for name, reference in expected.items():
         assert_allclose(res[name], reference, **tolerance)