def test_valid_param_combinations(self, batch_size, depth, height, width,
                                  size, device, dtype):
     """Smoke-check that valid argument combinations are accepted.

     The test passes as long as ``center_crop_generator3d`` returns without
     raising; the returned parameters are not inspected here. ``device`` and
     ``dtype`` are accepted (presumably pytest fixtures) but are not used.
     """
     crop_kwargs = dict(batch_size=batch_size, depth=depth, height=height,
                        width=width, size=size)
     center_crop_generator3d(**crop_kwargs)
 def test_invalid_param_combinations(self, depth, height, width, size,
                                     device, dtype):
     """Check that invalid depth/height/width/size combinations are rejected.

     ``batch_size`` is fixed at 2. Any exception type is accepted, because
     this test does not pin the specific error that
     ``center_crop_generator3d`` raises. ``device`` and ``dtype`` are unused.
     """
     bad_kwargs = {
         'batch_size': 2,
         'depth': depth,
         'height': height,
         'width': width,
         'size': size,
     }
     with pytest.raises(Exception):
         center_crop_generator3d(**bad_kwargs)
 def test_random_gen(self, device, dtype):
     """Check exact crop corners for a 200^3 volume cropped to (120, 150, 100).

     Judging by their extents, corner rows are ordered (x, y, z), i.e.
     (width, height, depth) axes. ``src`` holds the eight corners of the
     crop centred in the input volume. ``dst`` holds the eight corners of
     the output crop anchored at the origin. Both are repeated for a batch
     of 2 and compared against the generator output on ``device``.
     """
     # Fix the global RNG state before generating parameters.
     torch.manual_seed(42)
     res = center_crop_generator3d(batch_size=2, depth=200, height=200,
                                   width=200, size=(120, 150, 100))

     # Eight corners of the centred crop inside the 200x200x200 input.
     src_corners = [[50, 25, 40], [149, 25, 40], [149, 174, 40],
                    [50, 174, 40], [50, 25, 159], [149, 25, 159],
                    [149, 174, 159], [50, 174, 159]]
     # Matching eight corners of the output volume, starting at the origin.
     dst_corners = [[0, 0, 0], [99, 0, 0], [99, 149, 0], [0, 149, 0],
                    [0, 0, 119], [99, 0, 119], [99, 149, 119], [0, 149, 119]]
     expected = {
         name: torch.tensor([corners], device=device,
                            dtype=torch.long).repeat(2, 1, 1)
         for name, corners in (('src', src_corners), ('dst', dst_corners))
     }

     assert res.keys() == expected.keys()
     for name in ('src', 'dst'):
         assert_allclose(res[name].to(device=device), expected[name],
                         atol=1e-4, rtol=1e-4)
示例#4
0
 def generate_parameters(self, batch_shape: Size) -> Dict[str, Tensor]:
     """Build center-crop parameters for a batch of volumes.

     The batch size is ``batch_shape[0]``. The last three dimensions are
     forwarded positionally in the order ``[-3], [-2], [-1]``, presumably
     (depth, height, width); confirm this against the generator's signature.
     The crop size comes from ``self.flags["size"]``, and ``self.device`` is
     forwarded as ``device``.
     """
     batch_size = batch_shape[0]
     # Index each axis separately so a too-short shape fails the same way.
     depth, height, width = batch_shape[-3], batch_shape[-2], batch_shape[-1]
     crop_size = self.flags["size"]
     return rg.center_crop_generator3d(batch_size, depth, height, width,
                                       crop_size, device=self.device)