def test_roipool():
    """Check that ``roi_pool`` exported to ONNX matches the PyTorch result.

    For every hand-written ``(input, rois)`` case the op is evaluated on the
    GPU with PyTorch, exported to ``onnx_file`` (opset 11), executed through
    an ``rt.InferenceSession`` and the two outputs are compared with
    ``atol=1e-3``.

    The test is skipped under parrots, and also when CUDA is unavailable
    because the inputs are moved to the GPU with ``.cuda()``.
    """
    if torch.__version__ == 'parrots':
        pytest.skip('onnx is not supported in parrots directly')
    if not torch.cuda.is_available():
        # The tensors below are created with .cuda(); skip instead of
        # erroring out on CPU-only hosts.
        pytest.skip('test requires GPU')
    from mmcv.ops import roi_pool

    # roi pool config
    pool_h = 2
    pool_w = 2
    spatial_scale = 1.0
    # output_size is ordered (height, width), following roi_pool's
    # documented (h, w) convention.
    output_size = (pool_h, pool_w)

    # Each case is an (input feature map, rois) pair given as nested lists.
    inputs = [
        ([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]),
        ([[[[1., 2.], [3., 4.]], [[4., 3.], [2., 1.]]]],
         [[0., 0., 0., 1., 1.]]),
        ([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
            [11., 12., 15., 16.]]]], [[0., 0., 0., 3., 3.]]),
    ]

    def wrapped_function(torch_input, torch_rois):
        # Closure handed to WrapFunction so the op can be exported.
        return roi_pool(torch_input, torch_rois, output_size, spatial_scale)

    for case in inputs:
        np_input = np.array(case[0], dtype=np.float32)
        np_rois = np.array(case[1], dtype=np.float32)
        input_tensor = torch.from_numpy(np_input).cuda()
        rois = torch.from_numpy(np_rois).cuda()

        # compute pytorch_output
        with torch.no_grad():
            pytorch_output = roi_pool(input_tensor, rois, output_size,
                                      spatial_scale)
            pytorch_output = pytorch_output.cpu()

        # export and load onnx model
        wrapped_model = WrapFunction(wrapped_function)
        with torch.no_grad():
            torch.onnx.export(
                wrapped_model, (input_tensor, rois),
                onnx_file,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=['input', 'rois'],
                opset_version=11)
        onnx_model = onnx.load(onnx_file)

        # compute onnx_output
        # Graph inputs that are not initializers are the tensors that must
        # be fed at run time; exactly 'input' and 'rois' are expected.
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert len(net_feed_input) == 2

        sess = rt.InferenceSession(onnx_file)
        onnx_output = sess.run(
            None, {
                'input': input_tensor.detach().cpu().numpy(),
                'rois': rois.detach().cpu().numpy()
            })
        onnx_output = onnx_output[0]

        # allclose
        assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
def _test_roipool_allclose(self, dtype=torch.float): if not torch.cuda.is_available(): return from mmcv.ops import roi_pool pool_h = 2 pool_w = 2 spatial_scale = 1.0 for case, output in zip(inputs, outputs): np_input = np.array(case[0]) np_rois = np.array(case[1]) np_output = np.array(output[0]) np_grad = np.array(output[1]) x = torch.tensor(np_input, dtype=dtype, device='cuda', requires_grad=True) rois = torch.tensor(np_rois, dtype=dtype, device='cuda') output = roi_pool(x, rois, (pool_h, pool_w), spatial_scale) output.backward(torch.ones_like(output)) assert np.allclose(output.data.cpu().numpy(), np_output, 1e-3) assert np.allclose(x.grad.data.cpu().numpy(), np_grad, 1e-3)
def warpped_function(torch_input, torch_rois):
    """Apply ``roi_pool`` to the given input and RoIs.

    The pooling size and scale are not parameters: ``pool_w``, ``pool_h``
    and ``spatial_scale`` are names resolved outside this function.

    Args:
        torch_input: first argument forwarded to ``roi_pool``.
        torch_rois: second argument forwarded to ``roi_pool``.

    Returns:
        Whatever ``roi_pool`` returns for these arguments.
    """
    # NOTE(review): the tuple is built as (pool_w, pool_h); confirm this is
    # the order roi_pool expects for its output size.
    output_size = (pool_w, pool_h)
    return roi_pool(torch_input, torch_rois, output_size, spatial_scale)