Ejemplo n.º 1
0
 def test_invalid_param_combinations(self, input_size, size, resize_to,
                                     device, dtype):
     """Check that ``random_crop_generator3d`` rejects invalid parameters.

     ``input_size``, ``size`` and ``resize_to`` are supplied by the caller
     (presumably a pytest parametrization outside this view). A tensor
     ``size`` is moved to ``device``/``dtype`` before the call. The test
     passes only if the generator raises.

     Raises:
         AssertionError: if the generator accepts the parameter combination.
     """
     batch_size = 2
     try:
         random_crop_generator3d(batch_size=batch_size,
                                 input_size=input_size,
                                 size=size.to(device=device, dtype=dtype)
                                 if isinstance(size, torch.Tensor) else size,
                                 resize_to=resize_to)
     except Exception:
         # Any exception counts as a rejection (same semantics as
         # ``pytest.raises(Exception)``). The tensor conversion is kept
         # inside the ``try`` so it is covered by the same expectation.
         pass
     else:
         raise AssertionError(
             'random_crop_generator3d accepted an invalid combination: '
             f'input_size={input_size!r}, size={size!r}, '
             f'resize_to={resize_to!r}')
Ejemplo n.º 2
0
 def test_valid_param_combinations(self, batch_size, input_size, size,
                                   resize_to, same_on_batch, device, dtype):
     """Smoke-test ``random_crop_generator3d`` on accepted parameter sets.

     No output is inspected; the test passes as long as the call does not
     raise. A tensor ``size`` is tiled with ``repeat(batch_size, 1)`` and
     moved to ``device``/``dtype`` once; any other ``size`` value is
     forwarded unchanged.
     """
     crop_size = size
     if isinstance(crop_size, torch.Tensor):
         # NOTE(review): presumably produces one crop-size row per batch
         # element — confirm against the parametrized tensor shapes.
         crop_size = crop_size.repeat(batch_size, 1).to(device=device,
                                                        dtype=dtype)
     random_crop_generator3d(
         batch_size=batch_size,
         input_size=input_size,
         size=crop_size,
         resize_to=resize_to,
         same_on_batch=same_on_batch,
     )
Ejemplo n.º 3
0
 def test_same_on_batch(self, device, dtype):
     """Seeded crop generation with ``same_on_batch=True`` matches references.

     The global RNG is seeded with 42, a 2-sample batch is generated, and the
     returned ``src``/``dst`` corner tensors are compared against fixed
     values in which both samples share the same box.
     """

     def _corners(lo, hi):
         # Expand opposite corners (x0, y0, z0) / (x1, y1, z1) into the eight
         # box vertices, in the vertex order used by the reference data.
         (x0, y0, z0), (x1, y1, z1) = lo, hi
         return [[x0, y0, z0], [x1, y0, z0], [x1, y1, z0], [x0, y1, z0],
                 [x0, y0, z1], [x1, y0, z1], [x1, y1, z1], [x0, y1, z1]]

     torch.manual_seed(42)
     crop_size = torch.tensor([[50, 60, 70], [50, 60, 70]],
                              device=device, dtype=dtype)
     res = random_crop_generator3d(batch_size=2,
                                   input_size=(200, 200, 200),
                                   size=crop_size,
                                   resize_to=(100, 100, 100),
                                   same_on_batch=True)

     # The reference box is identical for both samples.
     src_box = _corners((7., 8., 18.), (76., 67., 67.))
     dst_box = _corners((0., 0., 0.), (99., 99., 99.))
     expected = {
         'src': torch.tensor([src_box, src_box], device=device, dtype=dtype),
         'dst': torch.tensor([dst_box, dst_box], device=device, dtype=dtype),
     }
     assert res.keys() == expected.keys()
     for key in ('src', 'dst'):
         assert_allclose(res[key], expected[key], atol=1e-4, rtol=1e-4)
Ejemplo n.º 4
0
 def test_random_gen(self, device, dtype):
     """Seeded crop generation reproduces known per-sample boxes.

     The global RNG is seeded with 42, a 2-sample batch is generated without
     ``same_on_batch``, and the returned ``src``/``dst`` corner tensors are
     compared against fixed reference values.
     """

     def _corners(lo, hi):
         # Expand opposite corners (x0, y0, z0) / (x1, y1, z1) into the eight
         # box vertices, in the vertex order used by the reference data.
         (x0, y0, z0), (x1, y1, z1) = lo, hi
         return [[x0, y0, z0], [x1, y0, z0], [x1, y1, z0], [x0, y1, z0],
                 [x0, y0, z1], [x1, y0, z1], [x1, y1, z1], [x0, y1, z1]]

     torch.manual_seed(42)
     crop_size = torch.tensor([[50, 60, 70], [50, 60, 70]],
                              device=device, dtype=dtype)
     res = random_crop_generator3d(batch_size=2,
                                   input_size=(200, 200, 200),
                                   size=crop_size,
                                   resize_to=(100, 100, 100))

     src_boxes = [
         _corners((115, 53, 58), (184, 112, 107)),
         _corners((119, 135, 90), (188, 194, 139)),
     ]
     dst_box = _corners((0, 0, 0), (99, 99, 99))
     expected = {
         'src': torch.tensor(src_boxes, device=device, dtype=dtype),
         'dst': torch.tensor([dst_box, dst_box], device=device, dtype=dtype),
     }
     assert res.keys() == expected.keys()
     for key in ('src', 'dst'):
         assert_allclose(res[key], expected[key], atol=1e-4, rtol=1e-4)
Ejemplo n.º 5
0
 def test_same_on_batch(self, device, dtype):
     """Seeded crop generation with ``same_on_batch=True`` matches references.

     The global RNG is seeded with 42, a 2-sample batch is generated, and the
     returned ``src``/``dst`` corner tensors are compared against fixed
     values in which both samples share the same box. Reference tensors are
     built with the parametrized ``dtype`` (previously hard-coded to
     ``torch.long``), so a dtype-checking comparison does not reject a
     floating-point result.
     """

     def _corners(lo, hi):
         # Expand opposite corners (x0, y0, z0) / (x1, y1, z1) into the eight
         # box vertices, in the vertex order used by the reference data.
         (x0, y0, z0), (x1, y1, z1) = lo, hi
         return [[x0, y0, z0], [x1, y0, z0], [x1, y1, z0], [x0, y1, z0],
                 [x0, y0, z1], [x1, y0, z1], [x1, y1, z1], [x0, y1, z1]]

     torch.manual_seed(42)
     crop_size = torch.tensor([[50, 60, 70], [50, 60, 70]],
                              device=device, dtype=dtype)
     res = random_crop_generator3d(batch_size=2,
                                   input_size=(200, 200, 200),
                                   size=crop_size,
                                   resize_to=(100, 100, 100),
                                   same_on_batch=True)

     # The reference box is identical for both samples.
     src_box = _corners((115, 129, 57), (184, 188, 106))
     dst_box = _corners((0, 0, 0), (99, 99, 99))
     expected = {
         'src': torch.tensor([src_box, src_box], device=device, dtype=dtype),
         'dst': torch.tensor([dst_box, dst_box], device=device, dtype=dtype),
     }
     assert res.keys() == expected.keys()
     for key in ('src', 'dst'):
         assert_allclose(res[key], expected[key], atol=1e-4, rtol=1e-4)