Exemple #1
0
 def test_ahnet_shape(self, input_param, input_shape, expected_shape, fcn_input_param):
     """Verify the AHNet output shape after calling ``copy_from`` with an FCN.

     Args:
         input_param: keyword arguments forwarded to ``AHNet``.
         input_shape: shape of the random tensor fed to the network.
         expected_shape: shape the forward output must have.
         fcn_input_param: keyword arguments forwarded to ``FCN``.
     """
     model = AHNet(**input_param).to(device)
     source_2d = FCN(**fcn_input_param).to(device)
     # presumably transfers the 2D FCN weights into the AHNet -- confirm against AHNet.copy_from
     model.copy_from(source_2d)
     with eval_mode(model):
         sample = torch.randn(input_shape).to(device)
         output = model.forward(sample)
         self.assertEqual(output.shape, expected_shape)
Exemple #2
0
 def test_script(self):
     """Pass a 2D and then a 3D AHNet, each with a random input, to ``test_script_save``."""
     # (constructor kwargs, input tensor shape) -- 2D network first, then 3D
     cases = (
         ({"spatial_dims": 2, "out_channels": 2}, (1, 1, 128, 64)),
         (
             {"spatial_dims": 3, "out_channels": 2, "psp_block_num": 0, "upsample_mode": "nearest"},
             (1, 1, 32, 32, 64),
         ),
     )
     for net_kwargs, data_shape in cases:
         net = AHNet(**net_kwargs)
         test_data = torch.randn(*data_shape)
         test_script_save(net, test_data)
Exemple #3
0
 def test_initialize_pretrained(self):
     """Build a 3D AHNet with ``pretrained=True`` and check its output shape.

     A random ``(2, 2, 128, 128, 64)`` batch is passed through the network and
     the result is expected to have shape ``(2, 3, 128, 128, 64)``.
     """
     net = AHNet(
         spatial_dims=3,
         upsample_mode="transpose",
         in_channels=2,
         out_channels=3,
         pretrained=True,
         progress=True,
     )
     # torch.no_grad() only turns off autograd; it does not switch modules out
     # of training mode. Call eval() explicitly so that any normalization or
     # dropout layers run in inference mode, as the sibling shape tests do.
     net.eval()
     input_data = torch.randn(2, 2, 128, 128, 64)
     with torch.no_grad():
         result = net.forward(input_data)
     self.assertEqual(result.shape, (2, 3, 128, 128, 64))
Exemple #4
0
 def test_initialize_pretrained(self):
     """Build a 3D AHNet with ``pretrained=True`` and check its output shape.

     The network is moved to ``device``, fed a random ``(2, 2, 32, 32, 64)``
     batch inside ``eval_mode``, and must return shape ``(2, 3, 32, 32, 64)``.
     """
     model_kwargs = {
         "spatial_dims": 3,
         "upsample_mode": "transpose",
         "in_channels": 2,
         "out_channels": 3,
         "psp_block_num": 2,
         "pretrained": True,
         "progress": True,
     }
     model = AHNet(**model_kwargs).to(device)
     batch = torch.randn(2, 2, 32, 32, 64).to(device)
     with eval_mode(model):
         output = model.forward(batch)
         self.assertEqual(output.shape, (2, 3, 32, 32, 64))
Exemple #5
0
 def test_ahnet_shape(self, input_param, input_data, expected_shape, fcn_input_param):
     """Verify the AHNet output shape after calling ``copy_from`` with an FCN.

     Args:
         input_param: keyword arguments forwarded to ``AHNet``.
         input_data: input passed directly to ``net.forward``.
         expected_shape: shape the forward output must have.
         fcn_input_param: keyword arguments forwarded to ``FCN``.
     """
     model = AHNet(**input_param)
     model.copy_from(FCN(**fcn_input_param))
     model.eval()
     with torch.no_grad():
         output_shape = model.forward(input_data).shape
     self.assertEqual(output_shape, expected_shape)
Exemple #6
0
 def test_fcn_shape(self, input_param, input_data, expected_shape):
     """Run a forward pass in eval mode and check the output shape.

     Args:
         input_param: keyword arguments forwarded to the network constructor.
         input_data: input passed directly to ``forward``.
         expected_shape: shape the forward output must have.
     """
     # NOTE(review): the test name says FCN but the model built here is AHNet;
     # confirm this is intended.
     model = AHNet(**input_param)
     model.eval()
     with torch.no_grad():
         output = model.forward(input_data)
     self.assertEqual(output.shape, expected_shape)
Exemple #7
0
 def test_script(self):
     """Check that the two outputs returned by ``test_script_save`` are close.

     A 3D AHNet and a random ``(1, 1, 128, 128, 64)`` input are handed to
     ``test_script_save``; the pair of tensors it returns must satisfy
     ``torch.allclose``.
     """
     net = AHNet(spatial_dims=3, out_channels=2)
     test_data = torch.randn(1, 1, 128, 128, 64)
     out_orig, out_reloaded = test_script_save(net, test_data)
     # Use the TestCase assertion rather than a bare ``assert``: bare asserts
     # are removed under ``python -O`` and would let this check pass silently.
     self.assertTrue(torch.allclose(out_orig, out_reloaded))
Exemple #8
0
 def test_ahnet_shape(self, input_param, input_shape, expected_shape):
     """Run an AHNet forward pass on random data and check the output shape.

     Args:
         input_param: keyword arguments forwarded to ``AHNet``.
         input_shape: shape of the random tensor fed to the network.
         expected_shape: shape the forward output must have.
     """
     model = AHNet(**input_param)
     model.eval()
     with torch.no_grad():
         sample = torch.randn(input_shape)
         output_shape = model.forward(sample).shape
     self.assertEqual(output_shape, expected_shape)
Exemple #9
0
 def test_ahnet_shape_3d(self, input_param, input_shape, expected_shape):
     """Run an AHNet forward pass on ``device`` and check the output shape.

     Args:
         input_param: keyword arguments forwarded to ``AHNet``.
         input_shape: shape of the random tensor fed to the network.
         expected_shape: shape the forward output must have.
     """
     model = AHNet(**input_param).to(device)
     with eval_mode(model):
         sample = torch.randn(input_shape).to(device)
         output = model.forward(sample)
         self.assertEqual(output.shape, expected_shape)
Exemple #10
0
 def test_script(self):
     """Pass a 2D AHNet and a random ``(1, 1, 128, 64)`` input to ``test_script_save``."""
     model = AHNet(out_channels=2, spatial_dims=2)
     sample_shape = (1, 1, 128, 64)
     test_script_save(model, torch.randn(sample_shape))