def test_with_zero_gamma(self):
    """Applying RandomGamma with ``log_gamma=0`` leaves the T1 data unchanged.

    NOTE(review): presumes ``log_gamma=0`` corresponds to an identity gamma
    (gamma = 1) — confirm against RandomGamma's implementation.
    """
    identity_transform = RandomGamma(log_gamma=0)
    output_subject = identity_transform(self.sample_subject)
    expected = self.sample_subject.t1.data
    actual = output_subject.t1.data
    self.assertTensorAlmostEqual(expected, actual)
def test_wrong_gamma_type(self):
    """Constructing RandomGamma with a string ``log_gamma`` raises ValueError."""
    invalid_log_gamma = 'wrong'
    with self.assertRaises(ValueError):
        RandomGamma(log_gamma=invalid_log_gamma)
def test_negative_values(self):
    """Transforming a tensor whose values lie in [-1, 0) emits a UserWarning."""
    with self.assertWarns(UserWarning):
        gamma_transform = RandomGamma()
        # torch.rand samples from [0, 1); subtracting 1 shifts it to [-1, 0).
        negative_tensor = torch.rand(1, 3, 3, 3) - 1
        gamma_transform(negative_tensor)
def test_with_low_gamma_subject(self):
    """A very negative log-gamma collapses the subject's T1 data to ``data > 0``.

    This test was previously named ``test_with_low_gamma``, the same name as
    the tensor-based test defined later in the class. The later definition
    replaced this one, so it never ran. The distinct name lets both execute.
    """
    transform = RandomGamma(log_gamma=(-100, -100))
    transformed = transform(self.sample_subject)
    # Expected output is the positivity mask of the input T1 data.
    expected = self.sample_subject.t1.data > 0
    self.assertTensorAlmostEqual(expected, transformed.t1.data)
def test_with_non_zero_gamma_subject(self):
    """A positive log-gamma range changes the subject's T1 data.

    This test was previously named ``test_with_non_zero_gamma``, the same
    name as the tensor-based test defined later in the class. The later
    definition replaced this one, so it never ran. The distinct name lets
    both execute.
    """
    transform = RandomGamma(log_gamma=(0.1, 0.3))
    transformed = transform(self.sample_subject)
    self.assertTensorNotEqual(
        self.sample_subject.t1.data,
        transformed.t1.data,
    )
def test_with_low_gamma(self):
    """A very negative log-gamma maps a random tensor to its ``> 0`` mask."""
    low_gamma_transform = RandomGamma(log_gamma=(-100, -100))
    input_tensor = self.get_random_tensor_zero_one()
    output = low_gamma_transform(input_tensor)
    expected_mask = input_tensor > 0
    self.assertTensorAlmostEqual(expected_mask, output)
def test_with_non_zero_gamma(self):
    """A positive log-gamma range produces output different from the input."""
    gamma_transform = RandomGamma(log_gamma=(0.1, 0.3))
    input_tensor = self.get_random_tensor_zero_one()
    output = gamma_transform(input_tensor)
    self.assertTensorNotEqual(input_tensor, output)
def test_with_high_gamma(self):
    """A very large log-gamma maps the subject's T1 data to ``data == 1``.

    Uses the ``self.sample_subject`` fixture, matching the other
    subject-based tests in this class. The previous ``self.sample`` does not
    match that fixture name and would raise AttributeError if only
    ``sample_subject`` is defined.
    """
    transform = RandomGamma(log_gamma=(100, 100))
    transformed = transform(self.sample_subject)
    # Expected output is the indicator of input voxels exactly equal to 1.
    expected = self.sample_subject.t1.data == 1
    self.assertTensorAlmostEqual(expected, transformed.t1.data)