def test_random_equalize(self, device, dtype):
    """Check RandomEqualize3D output and returned transform for p=1 and p=0.

    ``f``/``f1`` use ``return_transform=True`` and return an
    ``(output, transform)`` pair. Each pair is taken from a single forward
    call, so the output and transform are checked for the same invocation.
    ``f2``/``f3`` return only the output.
    """
    f = RandomEqualize3D(p=1.0, return_transform=True)
    f1 = RandomEqualize3D(p=0.0, return_transform=True)
    f2 = RandomEqualize3D(p=1.0)
    f3 = RandomEqualize3D(p=0.0)

    bs, channels, depth, height, width = 1, 3, 6, 10, 10
    inputs3d = self.build_input(channels, depth, height, width, bs, device=device, dtype=dtype)

    # Reference row of the equalized volume, built on the test device/dtype.
    row_expected = torch.tensor(
        [0.0000, 0.11764, 0.2353, 0.3529, 0.4706, 0.5882, 0.7059, 0.8235, 0.9412, 1.0000],
        device=device,
        dtype=dtype,
    )
    expected = self.build_input(channels, depth, height, width, bs=1, row=row_expected, device=device, dtype=dtype)
    # The test expects the reported transform to be a 4x4 identity shaped like `expected`'s batch.
    identity = kornia.eye_like(4, expected)

    # p=1.0: the input is equalized; a single call yields both output and transform.
    out, trans = f(inputs3d)
    assert_close(out, expected, rtol=1e-4, atol=1e-4)
    assert_close(trans, identity, rtol=1e-4, atol=1e-4)

    # p=0.0: the input passes through unchanged, and the transform is still identity.
    out1, trans1 = f1(inputs3d)
    assert_close(out1, inputs3d, rtol=1e-4, atol=1e-4)
    assert_close(trans1, identity, rtol=1e-4, atol=1e-4)

    # Without return_transform, only the output tensor is returned.
    assert_close(f2(inputs3d), expected, rtol=1e-4, atol=1e-4)
    assert_close(f3(inputs3d), inputs3d, rtol=1e-4, atol=1e-4)
def test_batch_random_equalize(self, device, dtype):
    """Check batched RandomEqualize3D output and its stored ``transform_matrix``.

    With p=1.0 every batch item is equalized; with p=0.0 the input is
    returned unchanged. In both cases ``transform_matrix`` is expected to be
    a batch of 4x4 identities.
    """
    f = RandomEqualize3D(p=1.0)
    f1 = RandomEqualize3D(p=0.0)

    bs, channels, depth, height, width = 2, 3, 6, 10, 10
    inputs3d = self.build_input(channels, depth, height, width, bs, device=device, dtype=dtype)

    # Build the reference row directly on the test device/dtype (as
    # test_random_equalize does) rather than as a default CPU float32 tensor.
    row_expected = torch.tensor(
        [0.0000, 0.11764, 0.2353, 0.3529, 0.4706, 0.5882, 0.7059, 0.8235, 0.9412, 1.0000],
        device=device,
        dtype=dtype,
    )
    expected = self.build_input(channels, depth, height, width, bs, row=row_expected, device=device, dtype=dtype)
    identity = kornia.eye_like(4, expected)  # 2 x 4 x 4

    assert_close(f(inputs3d), expected, rtol=1e-4, atol=1e-4)
    assert_close(f.transform_matrix, identity, rtol=1e-4, atol=1e-4)
    assert_close(f1(inputs3d), inputs3d, rtol=1e-4, atol=1e-4)
    assert_close(f1.transform_matrix, identity, rtol=1e-4, atol=1e-4)
def test_gradcheck(self, device, dtype):
    """Run gradcheck on RandomEqualize3D(p=0.5) over a seeded random 3x3x3 volume."""
    # Seed the global RNG for random reproducibility.
    torch.manual_seed(0)
    volume = utils.tensor_to_gradcheck_var(
        torch.rand((3, 3, 3), device=device, dtype=dtype)
    )
    augmentation = RandomEqualize3D(p=0.5)
    assert gradcheck(augmentation, (volume,), raise_exception=True)
def test_same_on_batch(self, device, dtype):
    """With ``same_on_batch=True``, identical batch items must yield identical outputs."""
    f = RandomEqualize3D(p=0.5, same_on_batch=True)
    # eye(4) -> (1, 1, 4, 4) -> repeat -> (2, 1, 2, 4, 4): two identical batch items.
    # Named `volume` rather than `input` to avoid shadowing the builtin.
    volume = torch.eye(4, device=device, dtype=dtype)
    volume = volume.unsqueeze(dim=0).unsqueeze(dim=0).repeat(2, 1, 2, 1, 1)
    res = f(volume)
    assert (res[0] == res[1]).all()
def test_smoke(self, device, dtype):
    """Check the string representation of a default-configured RandomEqualize3D."""
    f = RandomEqualize3D(p=0.5)
    # Named `expected_repr` rather than `repr` to avoid shadowing the builtin.
    expected_repr = "RandomEqualize3D(p=0.5, p_batch=1.0, same_on_batch=False, return_transform=False)"
    assert str(f) == expected_repr