def test_norm_min_max_transform(self):
    """Check that ``NormMinMax`` preserves its input and rescales data to [0, 1].

    The output bounds are asserted for both ``per_channel`` modes, not only
    the input-preservation property, so a regression in either code path
    is caught.
    """
    for per_channel in (False, True):
        with self.subTest(per_channel=per_channel):
            trafo = NormMinMax(per_channel=per_channel)
            self.assertTrue(chech_data_preservation(trafo, self.batch_dict))
            outp = trafo(**self.batch_dict)
            # assumes the global (per_channel=False) mode also maps the data
            # onto [0, 1]; per-channel scaling to [0, 1] keeps the global
            # min/max at 0/1 as well.
            self.assertAlmostEqual(outp["data"].min().item(), 0.0, delta=1e-6)
            self.assertAlmostEqual(outp["data"].max().item(), 1.0, delta=1e-6)
def test_norm_zero_mean_transform(self):
    """Check that ``NormZeroMeanUnitStd`` preserves its input and standardizes it.

    For both ``per_channel`` modes the transformed data must have mean 0
    and standard deviation 1.
    """
    for per_channel in (False, True):
        with self.subTest(per_channel=per_channel):
            trafo = NormZeroMeanUnitStd(per_channel=per_channel)
            self.assertTrue(chech_data_preservation(trafo, self.batch_dict))
            outp = trafo(**self.batch_dict)
            # NOTE(review): mean/std are measured over the whole tensor; an
            # exact global std of 1 presumably relies on the fixture holding
            # a single sample and channel — confirm against self.batch_dict.
            self.assertAlmostEqual(outp["data"].mean().item(), 0.0, delta=1e-6)
            self.assertAlmostEqual(outp["data"].std().item(), 1.0, delta=1e-6)
def test_norm_std_transform(self):
    """Check that ``NormMeanStd`` with the data's own statistics standardizes it.

    The fixture's global mean and std are fed to the transform. The result
    must have mean 0 and std 1 in both ``per_channel`` modes.
    """
    data = self.batch_dict["data"]
    mean = data.mean().item()
    std = data.std().item()
    for per_channel in (False, True):
        with self.subTest(per_channel=per_channel):
            # assumes scalar mean/std are applied to every channel when
            # per_channel=True — TODO confirm against NormMeanStd
            trafo = NormMeanStd(mean=mean, std=std, per_channel=per_channel)
            self.assertTrue(chech_data_preservation(trafo, self.batch_dict))
            outp = trafo(**self.batch_dict)
            self.assertAlmostEqual(outp["data"].mean().item(), 0.0, delta=1e-6)
            self.assertAlmostEqual(outp["data"].std().item(), 1.0, delta=1e-6)
def test_random_scale_value(self):
    """Scaling by a discrete parameter whose only value is 2 must double the data.

    NOTE(review): another method named ``test_random_scale_value`` appears
    later in this file. If both live in the same class, that later
    definition replaces this one and this test never runs — confirm.
    """
    scaler = RandomScaleValue(DiscreteParameter((2, )))
    self.assertTrue(chech_data_preservation(scaler, self.batch_dict))
    scaled = scaler(**self.batch_dict)["data"]
    doubled = self.batch_dict["data"] * 2.0
    self.assertTrue((scaled == doubled).all())
def test_clamp_transform(self):
    """``Clamp(0, 1)`` must preserve its input and produce an all-ones tensor.

    The all-ones expectation assumes every fixture value is >= 1 — TODO
    confirm against ``self.batch_dict``.
    """
    clamp = Clamp(0, 1)
    self.assertTrue(chech_data_preservation(clamp, self.batch_dict))
    clamped = clamp(**self.batch_dict)["data"]
    ones = torch.ones_like(clamped)
    self.assertTrue(torch.all(clamped == ones))
def test_mirror_transform(self):
    """Mirroring along dims (0, 1) must yield the hard-coded reversed 3x3 slice.

    NOTE(review): another method named ``test_mirror_transform`` appears
    later in this file. If both live in the same class, that later
    definition replaces this one and this test never runs — confirm.
    """
    mirror = Mirror((0, 1))
    first_slice = mirror(**self.batch_dict)["data"][0, 0]
    reversed_grid = torch.tensor([[9, 8, 7], [6, 5, 4], [3, 2, 1]]).float()
    self.assertTrue(first_slice.allclose(reversed_grid))
    # Preservation is checked after the transform has already been applied once.
    self.assertTrue(chech_data_preservation(mirror, self.batch_dict))
def test_gamma_transform_scalar(self):
    """A scalar ``gamma=2`` must preserve its input and square every value."""
    trafo = GammaCorrection(gamma=2)
    self.assertTrue(chech_data_preservation(trafo, self.batch_dict))
    # The instance validated above is reused for the numeric check;
    # constructing a second, identical transform adds nothing.
    outp = trafo(**self.batch_dict)
    expected_out = self.batch_dict["data"].pow(2)
    self.assertTrue((outp["data"] == expected_out).all())
def test_random_scale_value(self):
    """``RandomScaleValue("random")`` must scale the data by the next ``random.random()``.

    ``random`` is re-seeded with 0 both before drawing the reference factor
    and before calling the transform, so the two draws line up.

    NOTE(review): another method named ``test_random_scale_value`` appears
    earlier in this file. If both live in the same class, this definition
    replaces that one — confirm.
    """
    scaler = RandomScaleValue("random")
    self.assertTrue(chech_data_preservation(scaler, self.batch_dict))

    random.seed(0)
    factor = random.random()
    random.seed(0)
    scaled = scaler(**self.batch_dict)["data"]

    reference = self.batch_dict["data"] * factor
    self.assertTrue((scaled == reference).all())
    self.assertEqual(scaler.random_mode, "random")
def test_rot90_transform(self):
    """Check ``Rot90`` with ``prob=1`` and with ``prob=0``.

    With ``prob=1`` and ``num_rots=(1, )`` the first slice must match the
    hard-coded rotated grid. With ``prob=0`` the data must come back
    unchanged.
    """
    random.seed(0)
    rotate = Rot90((0, 1), prob=1, num_rots=(1, ))
    first_slice = rotate(**self.batch_dict)["data"][0, 0]
    rotated_grid = torch.tensor([[3, 6, 9], [2, 5, 8], [1, 4, 7]])
    self.assertTrue((first_slice == rotated_grid).all())
    self.assertTrue(chech_data_preservation(rotate, self.batch_dict))

    skip = Rot90((0, 1), prob=0)
    before = self.batch_dict["data"].clone()
    after = skip(**self.batch_dict)["data"]
    self.assertTrue((after == before).all())
def test_mirror_transform(self):
    """Check ``Mirror`` with ``prob=1`` and with ``prob=0``.

    With ``prob=1`` the first slice must equal the hard-coded reversed 3x3
    grid. With ``prob=0`` the data must come back unchanged.

    NOTE(review): another method named ``test_mirror_transform`` appears
    earlier in this file. If both live in the same class, this definition
    replaces that one — confirm.
    """
    always = Mirror((0, 1), prob=1)
    first_slice = always(**self.batch_dict)["data"][0, 0]
    reversed_grid = torch.tensor([[9, 8, 7], [6, 5, 4], [3, 2, 1]])
    self.assertTrue((first_slice == reversed_grid).all())
    self.assertTrue(chech_data_preservation(always, self.batch_dict))

    never = Mirror((0, 1), prob=0)
    before = self.batch_dict["data"].clone()
    after = never(**self.batch_dict)["data"]
    self.assertTrue((after == before).all())
def test_gaussian_noise_transform(self):
    """``GaussianNoise(mean=75, std=1)`` must preserve its input.

    The noise magnitude itself is checked by ``self.check_noise_distance``.
    """
    noise = GaussianNoise(mean=75, std=1)
    self.assertTrue(chech_data_preservation(noise, self.batch_dict))
    self.check_noise_distance(noise)
def test_expoential_noise_transform(self):
    """``ExponentialNoise(lambd=0.0001)`` must preserve its input.

    The noise magnitude itself is checked by ``self.check_noise_distance``.
    """
    # NOTE: the method name misspells "exponential"; it is kept unchanged
    # so that selecting this test by name keeps working.
    noise = ExponentialNoise(lambd=0.0001)
    self.assertTrue(chech_data_preservation(noise, self.batch_dict))
    self.check_noise_distance(noise)
def test_noise_transform(self):
    """``Noise`` in ``"normal"`` mode with mean 75 and std 1 must preserve its input.

    The noise magnitude itself is checked by ``self.check_noise_distance``.
    """
    noise = Noise("normal", mean=75, std=1)
    self.assertTrue(chech_data_preservation(noise, self.batch_dict))
    self.check_noise_distance(noise)