def test_ahnet_shape(self, input_param, input_shape, expected_shape, fcn_input_param):
    """Verify AHNet's output shape after seeding it from a 2D FCN.

    Args:
        input_param: keyword arguments for the ``AHNet`` under test.
        input_shape: shape of the random input tensor fed to the network.
        expected_shape: shape the network output must have.
        fcn_input_param: keyword arguments for the ``FCN`` whose weights are copied.
    """
    model = AHNet(**input_param).to(device)
    source_fcn = FCN(**fcn_input_param).to(device)
    # Transfer the 2D FCN's parameters into the AHNet before running it.
    model.copy_from(source_fcn)
    with eval_mode(model):
        sample = torch.randn(input_shape).to(device)
        output = model.forward(sample)
    self.assertEqual(output.shape, expected_shape)
def test_script(self):
    """Round-trip a 2D and a 3D AHNet through ``test_script_save``."""
    # (constructor kwargs, input shape) pairs, checked in order: 2D first, then 3D.
    cases = (
        ({"spatial_dims": 2, "out_channels": 2}, (1, 1, 128, 64)),
        (
            {"spatial_dims": 3, "out_channels": 2, "psp_block_num": 0, "upsample_mode": "nearest"},
            (1, 1, 32, 32, 64),
        ),
    )
    for net_kwargs, data_shape in cases:
        net = AHNet(**net_kwargs)
        test_script_save(net, torch.randn(*data_shape))
def test_initialize_pretrained(self):
    """Check that a pretrained 3D AHNet produces the expected output shape.

    Builds a 2-in / 3-out channel AHNet with ``pretrained=True`` (presumably this
    fetches pretrained weights, with ``progress=True`` showing progress — confirm
    against AHNet), then runs one batch through it.
    """
    net = AHNet(
        spatial_dims=3,
        upsample_mode="transpose",
        in_channels=2,
        out_channels=3,
        pretrained=True,
        progress=True,
    )
    # Switch to inference mode before the forward pass, matching the other shape
    # tests; torch.no_grad() alone leaves the module in training mode.
    net.eval()
    input_data = torch.randn(2, 2, 128, 128, 64)
    with torch.no_grad():
        result = net.forward(input_data)
    self.assertEqual(result.shape, (2, 3, 128, 128, 64))
def test_initialize_pretrained(self):
    """Check that a pretrained 3D AHNet with two PSP blocks yields the expected shape."""
    net_kwargs = {
        "spatial_dims": 3,
        "upsample_mode": "transpose",
        "in_channels": 2,
        "out_channels": 3,
        "psp_block_num": 2,
        "pretrained": True,
        "progress": True,
    }
    model = AHNet(**net_kwargs).to(device)
    batch = torch.randn(2, 2, 32, 32, 64).to(device)
    with eval_mode(model):
        prediction = model.forward(batch)
    self.assertEqual(prediction.shape, (2, 3, 32, 32, 64))
def test_ahnet_shape(self, input_param, input_data, expected_shape, fcn_input_param):
    """Verify AHNet's output shape after copying weights from an FCN.

    Args:
        input_param: keyword arguments for the ``AHNet`` under test.
        input_data: tensor passed directly to ``forward``.
        expected_shape: shape the output must have.
        fcn_input_param: keyword arguments for the source ``FCN``.
    """
    model = AHNet(**input_param)
    donor = FCN(**fcn_input_param)
    model.copy_from(donor)
    # Inference mode plus disabled autograd for the forward pass.
    model.eval()
    with torch.no_grad():
        output = model.forward(input_data)
    self.assertEqual(output.shape, expected_shape)
def test_fcn_shape(self, input_param, input_data, expected_shape):
    """Verify the FCN network's output shape.

    The original constructed ``AHNet`` here, so the FCN network named by this
    test was never exercised; it now instantiates ``FCN`` (already used by the
    AHNet-copy tests in this file).

    Args:
        input_param: keyword arguments for the ``FCN`` under test.
        input_data: tensor passed directly to ``forward``.
        expected_shape: shape the output must have.
    """
    net = FCN(**input_param)
    net.eval()
    with torch.no_grad():
        result = net.forward(input_data)
    self.assertEqual(result.shape, expected_shape)
def test_script(self):
    """Round-trip a 3D AHNet through ``test_script_save`` and compare outputs.

    ``test_script_save`` is expected here to return the outputs of the original
    and the reloaded module (as the tuple unpacking below assumes).
    """
    net = AHNet(spatial_dims=3, out_channels=2)
    test_data = torch.randn(1, 1, 128, 128, 64)
    out_orig, out_reloaded = test_script_save(net, test_data)
    # Use a unittest assertion rather than a bare ``assert``, which is stripped
    # under ``python -O`` and would silently turn this check into a no-op.
    self.assertTrue(torch.allclose(out_orig, out_reloaded))
def test_ahnet_shape(self, input_param, input_shape, expected_shape):
    """Run AHNet on a random tensor of ``input_shape`` and check the output shape."""
    model = AHNet(**input_param)
    model.eval()
    with torch.no_grad():
        sample = torch.randn(input_shape)
        output = model.forward(sample)
    self.assertEqual(output.shape, expected_shape)
def test_ahnet_shape_3d(self, input_param, input_shape, expected_shape):
    """Run an AHNet (on ``device``) over a random input and check the output shape."""
    model = AHNet(**input_param).to(device)
    with eval_mode(model):
        sample = torch.randn(input_shape).to(device)
        output = model.forward(sample)
    self.assertEqual(output.shape, expected_shape)
def test_script(self):
    """Round-trip a 2D AHNet through ``test_script_save``."""
    model = AHNet(spatial_dims=2, out_channels=2)
    sample = torch.randn(1, 1, 128, 64)
    test_script_save(model, sample)