def test_invalid_param_combinations(self, input_size, size, resize_to, device, dtype):
    """Invalid ``input_size`` / ``size`` / ``resize_to`` combinations must be rejected.

    The generator is expected to raise for every parametrized combination fed
    to this test; the test fails if a combination is silently accepted.
    """
    # Local import: the file-level import block is outside this view.
    import pytest

    batch_size = 2
    # Convert outside the ``raises`` block so that a failure in the tensor
    # conversion itself is reported as an error instead of being mistaken
    # for the expected rejection.
    if isinstance(size, torch.Tensor):
        size = size.to(device=device, dtype=dtype)
    # NOTE(review): ``Exception`` is deliberately broad because the concrete
    # exception type raised by the generator is not visible here; narrow it
    # (e.g. ValueError / AssertionError) once confirmed.
    with pytest.raises(Exception):
        random_crop_generator3d(
            batch_size=batch_size,
            input_size=input_size,
            size=size,
            resize_to=resize_to,
        )
def test_valid_param_combinations(self, batch_size, input_size, size, resize_to, same_on_batch, device, dtype):
    """Valid argument combinations must be accepted by ``random_crop_generator3d``.

    A tensor ``size`` is replicated ``batch_size`` times along dim 0 and moved
    to the requested device/dtype before the call; non-tensor sizes are passed
    through unchanged.
    """
    crop_size = size
    if isinstance(crop_size, torch.Tensor):
        # One conversion is enough: the tensor already has the target
        # device/dtype afterwards, so a second ``.to`` would be a no-op.
        crop_size = crop_size.repeat(batch_size, 1).to(device=device, dtype=dtype)

    random_crop_generator3d(
        batch_size=batch_size,
        input_size=input_size,
        size=crop_size,
        resize_to=resize_to,
        same_on_batch=same_on_batch,
    )
def test_same_on_batch(self, device, dtype):
    """With ``same_on_batch=True`` both batch entries must get the same crop box.

    NOTE(review): another method named ``test_same_on_batch`` appears right
    after ``test_random_gen`` in this file. If both are defined in the same
    class, this earlier definition is shadowed and pytest never runs it.
    The later version makes the identical seeded call (seed 42, same ``size``,
    same ``resize_to``, ``same_on_batch=True``) but expects different ``src``
    corners, so the two cannot both pass. Reconcile the expected values, then
    delete or rename one of the two definitions.
    """
    # Fixed seed so the sampled crop origins are reproducible.
    torch.manual_seed(42)
    res = random_crop_generator3d(batch_size=2, input_size=(200, 200, 200), size=torch.tensor(
        [[50, 60, 70], [50, 60, 70]], device=device, dtype=dtype), resize_to=(100, 100, 100), same_on_batch=True)
    expected = dict(
        # 8 corners per box; the spans are 70 (x) x 60 (y) x 50 (z) inclusive,
        # which suggests ``size`` rows are read as (d, h, w) -- TODO confirm.
        # Both batch entries are identical, as same_on_batch=True requires.
        src=torch.tensor([[[7., 8., 18.], [76., 8., 18.], [76., 67., 18.], [7., 67., 18.],
                           [7., 8., 67.], [76., 8., 67.], [76., 67., 67.], [7., 67., 67.]],
                          [[7., 8., 18.], [76., 8., 18.], [76., 67., 18.], [7., 67., 18.],
                           [7., 8., 67.], [76., 8., 67.], [76., 67., 67.], [7., 67., 67.]]],
                         device=device, dtype=dtype),
        # Destination box spans the full 0..99 range of resize_to=(100, 100, 100).
        dst=torch.tensor([[[0., 0., 0.], [99., 0., 0.], [99., 99., 0.], [0., 99., 0.],
                           [0., 0., 99.], [99., 0., 99.], [99., 99., 99.], [0., 99., 99.]],
                          [[0., 0., 0.], [99., 0., 0.], [99., 99., 0.], [0., 99., 0.],
                           [0., 0., 99.], [99., 0., 99.], [99., 99., 99.], [0., 99., 99.]]],
                         device=device, dtype=dtype),
    )
    assert res.keys() == expected.keys()
    assert_allclose(res['src'], expected['src'], atol=1e-4, rtol=1e-4)
    assert_allclose(res['dst'], expected['dst'], atol=1e-4, rtol=1e-4)
def test_random_gen(self, device, dtype):
    """Seeded regression test: per-sample crop boxes for a batch of two.

    With seed 42 each batch entry gets its own crop origin, and every
    destination box covers the full ``resize_to`` extent (0..99 per axis).
    """
    torch.manual_seed(42)
    crop_sizes = torch.tensor([[50, 60, 70], [50, 60, 70]], device=device, dtype=dtype)
    out = random_crop_generator3d(
        batch_size=2,
        input_size=(200, 200, 200),
        size=crop_sizes,
        resize_to=(100, 100, 100),
    )

    first_src = [[115, 53, 58], [184, 53, 58], [184, 112, 58], [115, 112, 58],
                 [115, 53, 107], [184, 53, 107], [184, 112, 107], [115, 112, 107]]
    second_src = [[119, 135, 90], [188, 135, 90], [188, 194, 90], [119, 194, 90],
                  [119, 135, 139], [188, 135, 139], [188, 194, 139], [119, 194, 139]]
    # The destination box is the same for every batch entry.
    dst_box = [[0, 0, 0], [99, 0, 0], [99, 99, 0], [0, 99, 0],
               [0, 0, 99], [99, 0, 99], [99, 99, 99], [0, 99, 99]]

    expected = {
        'src': torch.tensor([first_src, second_src], device=device, dtype=dtype),
        'dst': torch.tensor([dst_box, dst_box], device=device, dtype=dtype),
    }

    assert out.keys() == expected.keys()
    for key in ('src', 'dst'):
        assert_allclose(out[key], expected[key], atol=1e-4, rtol=1e-4)
def test_same_on_batch(self, device, dtype):
    """With ``same_on_batch=True`` every batch entry receives the same crop box.

    The expected boxes are therefore built from a single ``src`` box and a
    single ``dst`` box, each repeated once per batch entry.
    """
    torch.manual_seed(42)
    crop_sizes = torch.tensor([[50, 60, 70], [50, 60, 70]], device=device, dtype=dtype)
    out = random_crop_generator3d(
        batch_size=2,
        input_size=(200, 200, 200),
        size=crop_sizes,
        resize_to=(100, 100, 100),
        same_on_batch=True,
    )

    shared_src = [[115, 129, 57], [184, 129, 57], [184, 188, 57], [115, 188, 57],
                  [115, 129, 106], [184, 129, 106], [184, 188, 106], [115, 188, 106]]
    shared_dst = [[0, 0, 0], [99, 0, 0], [99, 99, 0], [0, 99, 0],
                  [0, 0, 99], [99, 0, 99], [99, 99, 99], [0, 99, 99]]

    # The expected tensors use an integer dtype, independent of the ``dtype``
    # parametrization used for the input sizes.
    expected = {
        'src': torch.tensor([shared_src, shared_src], device=device, dtype=torch.long),
        'dst': torch.tensor([shared_dst, shared_dst], device=device, dtype=torch.long),
    }

    assert out.keys() == expected.keys()
    for key in ('src', 'dst'):
        assert_allclose(out[key], expected[key], atol=1e-4, rtol=1e-4)