def _test_roialign_rotated_allclose(device, dtype):
    """Check ``roi_align_rotated`` forward and backward against fixtures.

    Iterates over the module-level ``inputs`` / ``outputs`` fixtures (defined
    elsewhere in this file). Each case is run on ``device`` with ``dtype``.
    The pooled output and the gradient w.r.t. the input features are compared
    against the expected arrays with ``atol=1e-3``. The backward pass uses an
    all-ones upstream gradient (``torch.ones_like``).

    Args:
        device: Torch device string. The test is skipped for ``'cuda'`` when
            CUDA is unavailable.
        dtype: Torch dtype used for both the features and the rois.
    """
    if not torch.cuda.is_available() and device == 'cuda':
        pytest.skip('unittest does not support GPU yet.')
    try:
        from mmcv.ops import roi_align_rotated
    except ImportError:
        # ImportError also covers its subclass ModuleNotFoundError, plus a
        # plain ImportError (e.g. if the compiled extension fails to load).
        # The ONNX test in this file catches ImportError for the same import.
        pytest.skip('test requires compilation')

    pool_h = 2
    pool_w = 2
    spatial_scale = 1.0
    sampling_ratio = 2

    for case, expected in zip(inputs, outputs):
        np_input = np.array(case[0])
        np_rois = np.array(case[1])
        np_output = np.array(expected[0])
        np_grad = np.array(expected[1])

        x = torch.tensor(
            np_input, dtype=dtype, device=device, requires_grad=True)
        rois = torch.tensor(np_rois, dtype=dtype, device=device)

        output = roi_align_rotated(x, rois, (pool_h, pool_w), spatial_scale,
                                   sampling_ratio, True)
        output.backward(torch.ones_like(output))

        # Cast to float32 on CPU so half-precision runs compare cleanly.
        assert np.allclose(
            output.data.type(torch.float).cpu().numpy(), np_output, atol=1e-3)
        assert np.allclose(
            x.grad.data.type(torch.float).cpu().numpy(), np_grad, atol=1e-3)
def warpped_function(torch_input, torch_rois):
    """Run ``roi_align_rotated`` on the given features and rois.

    ``roi_align_rotated``, ``pool_w``, ``pool_h``, ``spatial_scale`` and
    ``sampling_ratio`` are not defined here. They are resolved from the
    enclosing scope at call time.

    NOTE(review): at module level those names must exist as globals for this
    to be callable. Confirm this is not a stray copy of the nested helper
    inside ``test_roialign_rotated``.
    """
    op = roi_align_rotated
    out_size = (pool_w, pool_h)
    # The two trailing positional flags are forwarded unchanged.
    return op(torch_input, torch_rois, out_size, spatial_scale,
              sampling_ratio, True, False)
def test_roialign_rotated():
    """Export ``roi_align_rotated`` to ONNX and compare ORT with PyTorch.

    For each hand-written case, the test:

    1. Runs the op directly in PyTorch.
    2. Exports a wrapper around it with ``torch.onnx.export`` (opset 11) to
       ``onnx_file``, a module-level path defined elsewhere in this file.
    3. Runs the exported model with onnxruntime, with mmcv's custom-op
       library registered.
    4. Asserts the two outputs agree within ``atol=1e-3``.

    The exported file is removed after every case, even when that case fails.

    The test is skipped on parrots, when the mmcv ops cannot be imported, or
    when the onnxruntime custom-op library has not been built.
    """
    if torch.__version__ == 'parrots':
        pytest.skip('onnx is not supported in parrots directly')
    try:
        from mmcv.ops import roi_align_rotated
        from mmcv.ops import get_onnxruntime_op_path
    except ImportError:  # also covers ModuleNotFoundError
        pytest.skip('roi_align_rotated op is not successfully compiled')

    ort_custom_op_path = get_onnxruntime_op_path()
    if not os.path.exists(ort_custom_op_path):
        pytest.skip('custom ops for onnxruntime are not compiled.')

    # roi align config
    pool_h = 2
    pool_w = 2
    spatial_scale = 1.0
    sampling_ratio = 2
    # Ordered (height, width), matching _test_roialign_rotated_allclose.
    output_size = (pool_h, pool_w)

    # Each case is (features, rois).
    # NOTE(review): presumably each roi row is
    # (batch_idx, cx, cy, w, h, angle) -- confirm against the op docs.
    inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0.5, 0.5, 1., 1., 0]]),
              ([[[[1., 2.], [3., 4.]]]], [[0., 0.5, 0.5, 1., 1., np.pi / 2]]),
              ([[[[1., 2.], [3., 4.]],
                 [[4., 3.], [2., 1.]]]], [[0., 0.5, 0.5, 1., 1., 0]]),
              ([[[[1., 2., 5., 6.], [3., 4., 7., 8.],
                  [9., 10., 13., 14.], [11., 12., 15., 16.]]]],
               [[0., 1.5, 1.5, 3., 3., 0]]),
              ([[[[1., 2., 5., 6.], [3., 4., 7., 8.],
                  [9., 10., 13., 14.], [11., 12., 15., 16.]]]],
               [[0., 1.5, 1.5, 3., 3., np.pi / 2]])]

    def warpped_function(torch_input, torch_rois):
        return roi_align_rotated(torch_input, torch_rois, output_size,
                                 spatial_scale, sampling_ratio, True, False)

    # The wrapper does not depend on the case, so build it once.
    wrapped_model = WrapFunction(warpped_function)

    for case in inputs:
        np_input = np.array(case[0], dtype=np.float32)
        np_rois = np.array(case[1], dtype=np.float32)
        features = torch.from_numpy(np_input)
        rois = torch.from_numpy(np_rois)

        # compute pytorch_output
        with torch.no_grad():
            pytorch_output = roi_align_rotated(features, rois, output_size,
                                               spatial_scale, sampling_ratio,
                                               True, False)

        try:
            # export and load onnx model
            with torch.no_grad():
                torch.onnx.export(
                    wrapped_model, (features, rois),
                    onnx_file,
                    export_params=True,
                    keep_initializers_as_inputs=True,
                    input_names=['features', 'rois'],
                    opset_version=11)
            onnx_model = onnx.load(onnx_file)

            # The custom-op library is known to exist: the test skipped
            # above otherwise.
            session_options = rt.SessionOptions()
            session_options.register_custom_ops_library(ort_custom_op_path)

            # Initializers are also listed as graph inputs
            # (keep_initializers_as_inputs=True). Only the two real feeds
            # should remain after removing them.
            input_all = [node.name for node in onnx_model.graph.input]
            input_initializer = [
                node.name for node in onnx_model.graph.initializer
            ]
            net_feed_input = list(set(input_all) - set(input_initializer))
            assert len(net_feed_input) == 2

            # compute onnx_output
            sess = rt.InferenceSession(onnx_file, session_options)
            onnx_output = sess.run(None, {
                'features': features.detach().numpy(),
                'rois': rois.detach().numpy()
            })[0]
        finally:
            # Clean up even when export or inference fails. Check existence
            # first so a missing file cannot mask the original error.
            if os.path.exists(onnx_file):
                os.remove(onnx_file)

        # allclose
        assert np.allclose(pytorch_output.numpy(), onnx_output, atol=1e-3)