コード例 #1
0
 def test_vae_shape(self, input_param, input_shape, expected_shape):
     """Assert that the first output of ``SegResNetVAE`` has ``expected_shape``.

     Args:
         input_param: keyword arguments forwarded to the ``SegResNetVAE`` constructor.
         input_shape: shape of the random input tensor fed to the network.
         expected_shape: shape the first forward-pass output must have.
     """
     model = SegResNetVAE(**input_param)
     # Drawn after the model is built, matching the original construct-then-sample order.
     sample = torch.randn(input_shape)
     with torch.no_grad():
         # The forward pass is unpacked into exactly two values; only the first is checked.
         seg_out, _unused = model(sample)
     self.assertEqual(seg_out.shape, expected_shape)
コード例 #2
0
 def test_script(self):
     """Save ``SegResNetVAE`` as TorchScript and reload it via ``test_script_save``.

     Uses the first entry of ``TEST_CASE_SEGRESNET_VAE``; that entry's expected
     output shape is not needed here. Any output verification is performed inside
     ``test_script_save`` (not visible here) — its return value is ignored.
     """
     params, shape, _ = TEST_CASE_SEGRESNET_VAE[0]
     # Left-to-right argument evaluation keeps the construct-then-sample order.
     test_script_save(SegResNetVAE(**params), torch.randn(shape))
コード例 #3
0
 def test_script(self):
     """Round-trip ``SegResNetVAE`` through TorchScript and compare the outputs.

     Builds the network from the first ``TEST_CASE_SEGRESNET_VAE`` entry, saves
     and reloads it with ``test_script_save`` (which returns the original and
     reloaded outputs), and checks that the first element of each is element-wise
     close under ``torch.allclose``'s default tolerances.
     """
     input_param, input_shape, _ = TEST_CASE_SEGRESNET_VAE[0]
     net = SegResNetVAE(**input_param)
     test_data = torch.randn(input_shape)
     out_orig, out_reloaded = test_script_save(net, test_data)
     # A TestCase assertion instead of a bare ``assert``: bare asserts are removed
     # under ``python -O`` (the test would then pass vacuously) and unittest
     # reports them without a useful message.
     # NOTE(review): only index 0 is compared; the second output is presumably
     # the VAE loss — confirm whether it should be compared as well.
     self.assertTrue(
         torch.allclose(out_orig[0], out_reloaded[0]),
         "output of the reloaded TorchScript network differs from the original",
     )
コード例 #4
0
 def test_vae_shape(self, input_param, input_shape, expected_shape):
     """Assert that the first output of ``SegResNetVAE`` on ``device`` has ``expected_shape``.

     Args:
         input_param: keyword arguments forwarded to the ``SegResNetVAE`` constructor.
         input_shape: shape of the random input tensor fed to the network.
         expected_shape: shape the first forward-pass output must have.
     """
     # ``device`` is a module-level name defined outside this view.
     model = SegResNetVAE(**input_param).to(device)
     with eval_mode(model):
         # Sampled on the default device, then moved to ``device`` (same as before).
         sample = torch.randn(input_shape).to(device)
         seg_out, _unused = model(sample)
     self.assertEqual(seg_out.shape, expected_shape)