def test_vae_shape(self, input_param, input_shape, expected_shape):
    """Check the shape of the first output of ``SegResNetVAE``.

    A network is built from ``input_param`` and fed a random tensor of
    ``input_shape`` with autograd disabled. The call must return exactly two
    values. The first value's ``.shape`` must equal ``expected_shape``, and
    the second value is ignored.

    NOTE(review): no ``.eval()`` is called here, so the network runs in
    whatever mode it was constructed in (training mode, if it is a plain
    ``torch.nn.Module``).
    """
    model = SegResNetVAE(**input_param)
    # Draw the input only after the model exists, so the RNG consumption order
    # (weights first, then data) is unchanged.
    sample = torch.randn(input_shape)
    with torch.no_grad():
        seg_out, _unused = model(sample)
    self.assertEqual(seg_out.shape, expected_shape)
def test_script(self):
    """Run ``SegResNetVAE`` through ``test_script_save``.

    Only the first entry of ``TEST_CASE_SEGRESNET_VAE`` is used. That entry
    must unpack into exactly three items, and only the first two (constructor
    kwargs and input shape) are needed here. This method makes no assertions
    of its own. Any checking happens inside ``test_script_save``, which is
    defined elsewhere, and its return value is discarded.
    """
    input_param, input_shape, _unused_shape = TEST_CASE_SEGRESNET_VAE[0]
    # Arguments are evaluated left to right: the network is built before the
    # random input is drawn, matching the original ordering.
    test_script_save(SegResNetVAE(**input_param), torch.randn(input_shape))
def test_script(self):
    """Check that a TorchScript save/reload of ``SegResNetVAE`` keeps its output.

    Only the first entry of ``TEST_CASE_SEGRESNET_VAE`` is used; its third
    item is not needed. ``test_script_save`` (defined elsewhere) is unpacked
    into two results. Presumably these are the outputs of the original and
    the reloaded network; confirm against its definition. The first element
    of each is compared with ``torch.allclose`` at its default tolerances.
    """
    input_param, input_shape, _ = TEST_CASE_SEGRESNET_VAE[0]
    net = SegResNetVAE(**input_param)
    test_data = torch.randn(input_shape)
    out_orig, out_reloaded = test_script_save(net, test_data)
    # Use a TestCase assertion instead of a bare ``assert``. Bare asserts are
    # removed under ``python -O``, which would make this test pass silently,
    # and they report nothing useful when they fail.
    self.assertTrue(
        torch.allclose(out_orig[0], out_reloaded[0]),
        "output of the TorchScript-reloaded network differs from the original",
    )
def test_vae_shape(self, input_param, input_shape, expected_shape):
    """Check the shape of the first output of ``SegResNetVAE`` on ``device``.

    The network is built from ``input_param`` and moved to the module-level
    ``device`` (defined elsewhere). Inside the ``eval_mode`` context it is fed
    a random tensor of ``input_shape``, also moved to ``device``. The call
    must return exactly two values. The first value's ``.shape`` must equal
    ``expected_shape``, and the second value is ignored.

    NOTE(review): ``eval_mode`` is imported from elsewhere; presumably it
    switches ``model`` to evaluation mode for the duration of the block.
    """
    model = SegResNetVAE(**input_param).to(device)
    with eval_mode(model):
        # The input is created inside the context, as in the original, so the
        # RNG order and grad mode for its creation are unchanged.
        batch = torch.randn(input_shape).to(device)
        seg_out, _unused = model(batch)
    self.assertEqual(seg_out.shape, expected_shape)