예제 #1
0
 def test_2d(self):
     """Flip spatial axes 1 and 2 of a 2D sample with certainty.

     The transformed ``t1`` array must equal the input ``t1`` array with
     its last two array dims reversed.
     """
     sample_2d = self.make_2d(self.sample)
     flip = RandomFlip(axes=(1, 2), flip_probability=1)
     result = flip(sample_2d)
     # Spatial axes (1, 2) are checked against array dims 2 and 3; the first
     # array dim is presumably the channel dim (assumption, not shown here).
     expected = sample_2d.t1.data.numpy()[:, :, ::-1, ::-1]
     actual = result.t1.data.numpy()
     assert_array_equal(expected, actual)
예제 #2
0
 def test_apply_transform_to_file(self):
     """Run a default RandomFlip through ``apply_transform_to_file``.

     The 'input' image path is read and the 'output' image path is
     written. The test only checks that the call completes without raising.
     """
     flip = RandomFlip()
     source = self.get_image_path('input')
     destination = self.get_image_path('output')
     apply_transform_to_file(source, flip, destination, verbose=True)
예제 #3
0
 def test_2d(self):
     """Check that a certain flip of spatial axes 1 and 2 on a 2D sample
     reverses the last two dims of the ``t1`` array."""
     sample_2d = self.make_2d(self.sample)
     flipped = RandomFlip(axes=(1, 2), flip_probability=1)(sample_2d)
     # The reference array is read after the transform, as in the original test.
     original_array = sample_2d.t1.data.numpy()
     self.assertTensorEqual(
         original_array[..., ::-1, ::-1],
         flipped.t1.data.numpy(),
     )
예제 #4
0
 def test_anatomical_axis(self):
     """Check that flipping anatomical axis label 'i' with certainty
     reverses only the last dim of a bare (1, 2, 3, 4) tensor."""
     flip = RandomFlip(axes=['i'], flip_probability=1)
     tensor = torch.rand(1, 2, 3, 4)
     flipped = flip(tensor)
     expected = tensor.numpy()[..., ::-1]
     self.assertTensorEqual(expected, flipped.numpy())
예제 #5
0
 def test_tensor_flip(self):
     """Check that RandomFlip accepts a bare 4D tensor and preserves its
     shape and values.

     The input is all ones, so whichever axes are (randomly) flipped, the
     output must have the same shape and still contain only ones.
     """
     sample_input = torch.ones((4, 30, 30, 30))
     transformed = RandomFlip()(sample_input)
     # A flip only reorders elements: the shape cannot change, and an
     # all-ones input stays all ones.
     self.assertEqual(transformed.shape, sample_input.shape)
     self.assertTrue((transformed == 1).all())
예제 #6
0
 def test_history(self):
     """Check that applying RandomFlip once records exactly one history entry."""
     transformed = RandomFlip()(self.sample)
     # Use assertEqual, not assertIs: comparing ints by identity only passes
     # because of CPython's small-int cache, not because the values are equal.
     self.assertEqual(len(transformed.history), 1)
예제 #7
0
 def test_no_sample(self):
     """Check that flipping a subject whose image file is an empty
     temporary file raises RuntimeError.

     Nothing is ever written to the temp file. The error presumably comes
     from failing to load it as an image; confirm against the image loader.
     """
     with tempfile.NamedTemporaryFile() as temp_file:
         subject = Subject({'image': ScalarImage(temp_file.name)})
         with self.assertRaises(RuntimeError):
             RandomFlip()(subject)
예제 #8
0
 def test_wrong_flip_probability_type(self):
     """Check that a string ``flip_probability`` is rejected with ValueError."""
     self.assertRaises(ValueError, RandomFlip, flip_probability='wrong')
예제 #9
0
 def test_wrong_axes_type(self):
     """Check that ``axes=None`` is rejected with ValueError."""
     self.assertRaises(ValueError, RandomFlip, axes=None)
예제 #10
0
 def test_out_of_range_axis_in_tuple(self):
     """Check that an axes tuple containing an invalid entry is rejected
     with ValueError.

     The offending entry is presumably -1; 0 and 2 look like valid axes.
     """
     self.assertRaises(ValueError, RandomFlip, axes=(0, -1, 2))
예제 #11
0
 def test_out_of_range_axis(self):
     """Check that a single axis index of 3 is rejected with ValueError.

     Valid spatial axes are presumably 0, 1 and 2; confirm against RandomFlip.
     """
     self.assertRaises(ValueError, RandomFlip, axes=3)