def test_ahnet_shape(self, input_param, input_shape, expected_shape, fcn_input_param):
    """Check that an AHNet initialised from a 2D FCN produces the expected output shape.

    Args:
        input_param: keyword arguments used to construct ``AHNet``.
        input_shape: shape of the random input tensor fed through the network.
        expected_shape: shape the network output is asserted to have.
        fcn_input_param: keyword arguments used to construct the ``FCN`` that is
            passed to ``AHNet.copy_from``.
    """
    net = AHNet(**input_param).to(device)
    net2d = FCN(**fcn_input_param).to(device)
    # NOTE(review): presumably copies the 2D FCN weights into the AHNet -- confirm
    # against AHNet.copy_from.
    net.copy_from(net2d)
    # Allocate the input directly on the target device rather than on CPU + copy.
    data = torch.randn(input_shape, device=device)
    with eval_mode(net):
        # Call the module itself (not .forward) so registered hooks are not bypassed.
        result = net(data)
    self.assertEqual(result.shape, expected_shape)
def test_ahnet_shape(self, input_param, input_data, expected_shape, fcn_input_param):
    """Check that an AHNet initialised from a 2D FCN maps ``input_data`` to the expected shape.

    Args:
        input_param: keyword arguments used to construct ``AHNet``.
        input_data: tensor fed through the network as-is (no device transfer).
        expected_shape: shape the network output is asserted to have.
        fcn_input_param: keyword arguments used to construct the ``FCN`` that is
            passed to ``AHNet.copy_from``.
    """
    net = AHNet(**input_param)
    net2d = FCN(**fcn_input_param)
    # NOTE(review): presumably copies the 2D FCN weights into the AHNet -- confirm
    # against AHNet.copy_from.
    net.copy_from(net2d)
    # Inference only: switch layers to eval behaviour and skip autograd bookkeeping.
    net.eval()
    with torch.no_grad():
        # Call the module itself (not .forward) so registered hooks are not bypassed.
        result = net(input_data)
    self.assertEqual(result.shape, expected_shape)