def test_batch_random_vflip(self, device):
    """Batched RandomVerticalFlip3D: p=1.0 always flips, p=0.0 never does.

    A single 1 x 1 x 1 x 3 x 3 volume is tiled to a 5 x 3 x 1 x 3 x 3 batch.
    With p=1.0 every sample must have the row order of its 3 x 3 slices
    reversed and report the same flip matrix; with p=0.0 the batch must come
    back unchanged with identity matrices.
    """
    f = RandomVerticalFlip3D(p=1.0, return_transform=True)
    f1 = RandomVerticalFlip3D(p=0.0, return_transform=True)

    volume = torch.tensor(
        [[[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]]], device=device
    )  # 1 x 1 x 1 x 3 x 3
    expected = torch.tensor(
        [[[[[0.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]]], device=device
    )  # 1 x 1 x 1 x 3 x 3
    # y' = -y + 2 maps row index 0 <-> 2 for a height of 3.
    expected_transform = torch.tensor(
        [[[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 2.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]],
        device=device,
    )  # 1 x 4 x 4
    identity = torch.eye(4, device=device).unsqueeze(0)  # 1 x 4 x 4

    volume = volume.repeat(5, 3, 1, 1, 1)  # 5 x 3 x 1 x 3 x 3
    expected = expected.repeat(5, 3, 1, 1, 1)  # 5 x 3 x 1 x 3 x 3
    expected_transform = expected_transform.repeat(5, 1, 1)  # 5 x 4 x 4
    identity = identity.repeat(5, 1, 1)  # 5 x 4 x 4

    # Run each augmentation once and check both returned elements.
    flipped = f(volume)
    assert_allclose(flipped[0], expected)
    assert_allclose(flipped[1], expected_transform)

    untouched = f1(volume)
    assert_allclose(untouched[0], volume)
    assert_allclose(untouched[1], identity)
def test_sequential(self, device):
    """Chain two p=1 vertical flips with nn.Sequential.

    ``f`` chains two flips that both request transforms, so the reported
    matrix is the product of both flip matrices. In ``f1`` only the first
    flip requests a transform, and the test expects just that single flip
    matrix back. In both cases the two flips cancel on the data.
    """
    f = nn.Sequential(
        RandomVerticalFlip3D(p=1.0, return_transform=True),
        RandomVerticalFlip3D(p=1.0, return_transform=True),
    )
    f1 = nn.Sequential(
        RandomVerticalFlip3D(p=1.0, return_transform=True),
        RandomVerticalFlip3D(p=1.0),
    )

    volume = torch.tensor(
        [[[[[0., 0., 0.], [0., 0., 0.], [0., 1., 1.]]]]], device=device
    )  # 1 x 1 x 1 x 3 x 3
    flip_matrix = torch.tensor(
        [[[1., 0., 0., 0.], [0., -1., 0., 2.], [0., 0., 1., 0.], [0., 0., 0., 1.]]],
        device=device,
    )  # 1 x 4 x 4
    double_flip = flip_matrix @ flip_matrix

    # NOTE(review): outputs are compared against the squeezed (3 x 3) input;
    # presumably this relies on assert_allclose broadcasting — confirm.
    both_tracked = f(volume)
    assert_allclose(both_tracked[0], volume.squeeze())
    assert_allclose(both_tracked[1], double_flip)

    first_tracked = f1(volume)
    assert_allclose(first_tracked[0], volume.squeeze())
    assert_allclose(first_tracked[1], flip_matrix)
def test_random_vflip(self, device, dtype):
    """RandomVerticalFlip3D on a 1 x 1 x 2 x 3 x 3 volume, with and without transforms.

    p=1.0 must reverse the row order of each 3 x 3 slice and (when
    ``return_transform=True``) report the vertical-flip matrix. p=0.0 must
    return the input unchanged and the identity matrix.
    """
    f = RandomVerticalFlip3D(p=1.0, return_transform=True)
    f1 = RandomVerticalFlip3D(p=0., return_transform=True)
    f2 = RandomVerticalFlip3D(p=1.)
    f3 = RandomVerticalFlip3D(p=0.)

    volume = torch.tensor(
        [[[[[0., 0., 0.], [0., 0., 0.], [0., 1., 1.]],
           [[0., 0., 0.], [0., 0., 0.], [0., 1., 1.]]]]],
        device=device, dtype=dtype,
    )  # 1 x 1 x 2 x 3 x 3
    expected = torch.tensor(
        [[[[[0., 1., 1.], [0., 0., 0.], [0., 0., 0.]],
           [[0., 1., 1.], [0., 0., 0.], [0., 0., 0.]]]]],
        device=device, dtype=dtype,
    )  # 1 x 1 x 2 x 3 x 3
    expected_transform = torch.tensor(
        [[[1., 0., 0., 0.], [0., -1., 0., 2.], [0., 0., 1., 0.], [0., 0., 0., 1.]]],
        device=device, dtype=dtype,
    )  # 1 x 4 x 4
    identity = torch.eye(4, device=device, dtype=dtype).unsqueeze(0)  # 1 x 4 x 4

    flipped = f(volume)
    assert_allclose(flipped[0], expected)
    assert_allclose(flipped[1], expected_transform)

    untouched = f1(volume)
    assert_allclose(untouched[0], volume)
    assert_allclose(untouched[1], identity)

    assert_allclose(f2(volume), expected)
    assert_allclose(f3(volume), volume)
def test_random_vflip(self, device):
    """RandomVerticalFlip3D on an unbatched 2 x 3 x 3 volume.

    p=1.0 must reverse the row order of each 3 x 3 slice and report the
    vertical-flip matrix; p=0.0 must leave the volume as-is and report the
    identity. The two plain (no ``return_transform``) modules repeat the
    data checks alone.
    """
    flip_always = RandomVerticalFlip3D(p=1.0, return_transform=True)
    flip_never = RandomVerticalFlip3D(p=0., return_transform=True)
    flip_always_plain = RandomVerticalFlip3D(p=1.)
    flip_never_plain = RandomVerticalFlip3D(p=0.)

    slice_in = [[0., 0., 0.], [0., 0., 0.], [0., 1., 1.]]
    slice_out = [[0., 1., 1.], [0., 0., 0.], [0., 0., 0.]]

    volume = torch.tensor([slice_in, slice_in], device=device)  # 2 x 3 x 3
    flipped_volume = torch.tensor([slice_out, slice_out], device=device)  # 2 x 3 x 3
    flip_matrix = torch.tensor(
        [[1., 0., 0., 0.], [0., -1., 0., 2.], [0., 0., 1., 0.], [0., 0., 0., 1.]],
        device=device,
    )  # 4 x 4
    eye4 = torch.eye(4, device=device)  # 4 x 4

    assert_allclose(flip_always(volume)[0], flipped_volume)
    assert_allclose(flip_always(volume)[1], flip_matrix)
    assert_allclose(flip_never(volume)[0], volume)
    assert_allclose(flip_never(volume)[1], eye4)
    assert_allclose(flip_always_plain(volume), flipped_volume)
    assert_allclose(flip_never_plain(volume), volume)
def test_sequential(self, device):
    """Two chained p=1 flips in AugmentationSequential cancel out.

    The pipeline output must equal the input. ``transform_matrix`` must equal
    the product of the two flip matrices, which is numerically the identity.
    """
    f = AugmentationSequential(RandomVerticalFlip3D(p=1.0), RandomVerticalFlip3D(p=1.0))

    volume = torch.tensor(
        [[[[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]]]]], device=device
    )  # 1 x 1 x 1 x 3 x 3
    flip_matrix = torch.tensor(
        [[[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 2.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]],
        device=device,
    )  # 1 x 4 x 4
    double_flip = flip_matrix @ flip_matrix

    assert_close(f(volume), volume)
    # transform_matrix is read after the forward call above.
    assert_close(f.transform_matrix, double_flip)
def test_random_vflip(self, device, dtype):
    """RandomVerticalFlip3D with p=1.0 / p=0.0 on a 1 x 1 x 2 x 3 x 3 volume.

    p=1.0 must reverse the row order of each 3 x 3 slice and store the
    vertical-flip matrix in ``transform_matrix``. p=0.0 must leave the input
    unchanged and store the identity.
    """
    f = RandomVerticalFlip3D(p=1.0)
    f1 = RandomVerticalFlip3D(p=0.0)

    volume = torch.tensor(
        [[[
            [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]],
            [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 1.0]],
        ]]],
        device=device,
        dtype=dtype,
    )  # 1 x 1 x 2 x 3 x 3
    expected = torch.tensor(
        [[[
            [[0.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
            [[0.0, 1.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
        ]]],
        device=device,
        dtype=dtype,
    )  # 1 x 1 x 2 x 3 x 3
    expected_transform = torch.tensor(
        [[[1.0, 0.0, 0.0, 0.0], [0.0, -1.0, 0.0, 2.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]],
        device=device,
        dtype=dtype,
    )  # 1 x 4 x 4
    identity = torch.eye(4, device=device, dtype=dtype).unsqueeze(0)  # 1 x 4 x 4

    assert_close(f(volume), expected)
    assert_close(f.transform_matrix, expected_transform)
    assert_close(f1(volume), volume)
    assert_close(f1.transform_matrix, identity)
def test_gradcheck(self, device):
    """Gradients through RandomVerticalFlip3D(p=1.) must pass ``gradcheck``."""
    volume = torch.rand((1, 3, 3)).to(device)  # 1 x 3 x 3
    # NOTE(review): presumably converts to a double tensor with requires_grad,
    # as gradcheck expects — confirm against utils.tensor_to_gradcheck_var.
    volume = utils.tensor_to_gradcheck_var(volume)
    assert gradcheck(RandomVerticalFlip3D(p=1.), (volume,), raise_exception=True)
def test_same_on_batch(self, device):
    """With same_on_batch=True, two identical samples must come out identical."""
    flip = RandomVerticalFlip3D(p=0.5, same_on_batch=True)
    # Two copies of a 3 x 3 identity slice: 2 x 1 x 1 x 3 x 3.
    batch = torch.eye(3)[None, None].repeat(2, 1, 1, 1, 1)
    out = flip(batch)
    first, second = out[0], out[1]
    assert (first == second).all()
def test_smoke(self):
    """``str(RandomVerticalFlip3D(0.5))`` must match the expected repr string."""
    # NOTE(review): the expected string reports return_transform=0.5, i.e. the
    # positional 0.5 appears to bind to return_transform rather than p — confirm
    # against the constructor's parameter order before relying on this.
    expected_repr = "RandomVerticalFlip3D(p=0.5, p_batch=1.0, same_on_batch=False, return_transform=0.5)"
    augmentation = RandomVerticalFlip3D(0.5)
    assert str(augmentation) == expected_repr
def smoke_test(self, device):
    """Check the repr of ``RandomVerticalFlip3D(0.5)`` (older expected format)."""
    # NOTE(review): this name lacks the ``test_`` prefix, so pytest's default
    # discovery does not collect it. Its expected string also disagrees with
    # test_smoke's, so renaming it would add a contradictory check — decide
    # which format is current and delete the other. ``device`` is unused.
    expected_repr = "RandomVerticalFlip3D(p=0.5, return_transform=False)"
    augmentation = RandomVerticalFlip3D(0.5)
    assert str(augmentation) == expected_repr