def _test_roialign_allclose(device, dtype):
    """Compare roi_align forward/backward against precomputed references.

    NOTE(review): `inputs` and `outputs` are expected to be defined at
    module level (pairs of (feature, rois) and (output, grad) arrays).
    """
    if not torch.cuda.is_available() and device == 'cuda':
        pytest.skip('test requires GPU')
    try:
        from mmcv.ops import roi_align
    except ModuleNotFoundError:
        pytest.skip('test requires compilation')

    # RoIAlign configuration shared by every case.
    pool_h, pool_w = 2, 2
    spatial_scale = 1.0
    sampling_ratio = 2

    for case, expected in zip(inputs, outputs):
        feat_np = np.array(case[0])
        rois_np = np.array(case[1])
        want_out = np.array(expected[0])
        want_grad = np.array(expected[1])

        feats = torch.tensor(
            feat_np, dtype=dtype, device=device, requires_grad=True)
        rois = torch.tensor(rois_np, dtype=dtype, device=device)

        result = roi_align(feats, rois, (pool_h, pool_w), spatial_scale,
                           sampling_ratio, 'avg', True)
        result.backward(torch.ones_like(result))

        # Forward output and input gradient must both match the references.
        assert np.allclose(
            result.data.type(torch.float).cpu().numpy(), want_out, atol=1e-3)
        assert np.allclose(
            feats.grad.data.type(torch.float).cpu().numpy(),
            want_grad,
            atol=1e-3)
def test_roialign():
    """Export mmcv roi_align to ONNX and check onnxruntime matches PyTorch."""
    from mmcv.ops import roi_align

    # roi align config
    pool_h = 2
    pool_w = 2
    spatial_scale = 1.0
    sampling_ratio = 2

    # (feature map, rois) pairs covering 1/2 channels and 2x2/4x4 inputs.
    inputs = [
        ([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]),
        ([[[[1., 2.], [3., 4.]], [[4., 3.], [2., 1.]]]],
         [[0., 0., 0., 1., 1.]]),
        ([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
            [11., 12., 15., 16.]]]], [[0., 0., 0., 3., 3.]]),
    ]

    def roi_align_fn(torch_input, torch_rois):
        # Wrapper with every argument except the tensors baked in, so the
        # exported graph takes exactly (input, rois).
        return roi_align(torch_input, torch_rois, (pool_w, pool_h),
                         spatial_scale, sampling_ratio, 'avg', True)

    for feat_list, roi_list in inputs:
        feat = torch.from_numpy(np.array(feat_list, dtype=np.float32))
        rois = torch.from_numpy(np.array(roi_list, dtype=np.float32))

        # Reference result from the eager PyTorch op.
        with torch.no_grad():
            pytorch_output = roi_align(feat, rois, (pool_w, pool_h),
                                       spatial_scale, sampling_ratio, 'avg',
                                       True)

        # Export the wrapped op to ONNX and reload the model.
        wrapped_model = WrapFunction(roi_align_fn)
        with torch.no_grad():
            torch.onnx.export(
                wrapped_model, (feat, rois),
                onnx_file,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=['input', 'rois'],
                opset_version=11)
        onnx_model = onnx.load(onnx_file)

        # Only the two runtime tensors should remain as graph inputs once
        # the initializers are excluded.
        graph_inputs = [node.name for node in onnx_model.graph.input]
        initializers = [node.name for node in onnx_model.graph.initializer]
        net_feed_input = list(set(graph_inputs) - set(initializers))
        assert (len(net_feed_input) == 2)

        sess = rt.InferenceSession(onnx_file)
        onnx_output = sess.run(
            None, {
                'input': feat.detach().numpy(),
                'rois': rois.detach().numpy()
            })[0]

        os.remove(onnx_file)
        assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
def crop_image_patch_v2(pos_proposals, pos_assigned_gt_inds, gt_masks):
    """Crop 28x28 mask targets from `gt_masks` using RoIAlign.

    NOTE(review): assumes `gt_masks` is a numpy array indexable along dim 0
    by `pos_assigned_gt_inds` — confirm against the caller.
    """
    import torch
    from torch.nn.modules.utils import _pair

    device = pos_proposals.device
    num_pos = pos_proposals.size(0)

    # Prepend a per-row batch index so each proposal becomes a 5-column roi.
    batch_inds = torch.arange(
        num_pos, device=device).to(dtype=pos_proposals.dtype)[:, None]
    rois = torch.cat([batch_inds, pos_proposals], dim=1)  # Nx5
    mask_size = _pair(28)
    rois = rois.to(device=device)

    # Select the ground-truth mask for every positive proposal and match
    # the roi dtype so RoIAlign gets consistent inputs.
    mask_tensor = torch.from_numpy(gt_masks).to(device).index_select(
        0, pos_assigned_gt_inds).to(dtype=rois.dtype)

    # Use RoIAlign could apparently accelerate the training (~0.1s/iter)
    mask_targets = roi_align(mask_tensor, rois, mask_size[::-1], 1.0, 0,
                             True).squeeze(1)
    return mask_targets
def bbox_feat_extractor(feature_maps, boxes, pool_size):
    """Extract fixed-size RoI features from a single-image feature map.

    Args:
        feature_maps: feature tensor of size 1*C*h*w (a 3-D C*h*w tensor is
            also accepted; the batch dimension is added automatically).
        boxes: float boxes with columns (y1, x1, y2, x2)
            **with normalization**.
        pool_size: output spatial size forwarded to ``roi_align``.

    Returns:
        Pooled features, one entry per box. NOTE(review): the original
        comment claimed [num_boxes, pool_h, pool_w, channels]; mmcv's
        roi_align emits channel-first output — confirm against callers.
    """
    # Currently only supports batch_size 1
    # Reorder (y1, x1, y2, x2) -> (x1, y1, x2, y2) as roi_align expects.
    boxes = boxes[:, [1, 0, 3, 2]]

    # Batch index of each box (all zero: single image). new_zeros matches
    # the dtype AND device of `boxes`, so the cat below cannot fail on a
    # dtype mismatch (torch.zeros(...) was always float32 on CPU, which
    # broke fp16/fp64 inputs) and the manual .cuda() branch is unnecessary.
    box_ind = boxes.new_zeros(boxes.size(0))

    # roi_align needs a 4-D feature tensor; add the batch dimension.
    if len(feature_maps.size()) == 3:
        feature_maps = feature_maps.unsqueeze(0)

    # make crops:
    rois = torch.cat([box_ind.unsqueeze(1), boxes], dim=1)
    return roi_align(feature_maps, rois, pool_size)
def warpped_function(torch_input, torch_rois):
    """Run roi_align with the enclosing-scope configuration baked in.

    NOTE(review): relies on ``roi_align``, ``pool_w``, ``pool_h``,
    ``spatial_scale`` and ``sampling_ratio`` being defined in an outer
    scope — this reads like a closure lifted out of a test function.
    """
    output_size = (pool_w, pool_h)
    return roi_align(torch_input, torch_rois, output_size, spatial_scale,
                     sampling_ratio, 'avg', True)
def test_roialign():
    """Export roi_align to ONNX and verify onnxruntime (with the mmcv
    custom-op library, when built) matches the eager PyTorch result."""
    if torch.__version__ == 'parrots':
        pytest.skip('onnx is not supported in parrots directly')
    try:
        from mmcv.ops import roi_align
        from mmcv.ops import get_onnxruntime_op_path
    except (ImportError, ModuleNotFoundError):
        pytest.skip('roi_align op is not successfully compiled')

    ort_custom_op_path = get_onnxruntime_op_path()

    # roi align config
    pool_h = 2
    pool_w = 2
    spatial_scale = 1.0
    sampling_ratio = 2

    # (feature map, rois) pairs covering 1/2 channels and 2x2/4x4 inputs.
    inputs = [
        ([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]),
        ([[[[1., 2.], [3., 4.]], [[4., 3.], [2., 1.]]]],
         [[0., 0., 0., 1., 1.]]),
        ([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
            [11., 12., 15., 16.]]]], [[0., 0., 0., 3., 3.]]),
    ]

    def roi_align_fn(torch_input, torch_rois):
        # Wrapper with every argument except the tensors baked in, so the
        # exported graph takes exactly (input, rois).
        return roi_align(torch_input, torch_rois, (pool_w, pool_h),
                         spatial_scale, sampling_ratio, 'avg', True)

    for feat_list, roi_list in inputs:
        feat = torch.from_numpy(np.array(feat_list, dtype=np.float32))
        rois = torch.from_numpy(np.array(roi_list, dtype=np.float32))

        # Reference result from the eager PyTorch op.
        with torch.no_grad():
            pytorch_output = roi_align(feat, rois, (pool_w, pool_h),
                                       spatial_scale, sampling_ratio, 'avg',
                                       True)

        # Export the wrapped op to ONNX and reload the model.
        wrapped_model = WrapFunction(roi_align_fn)
        with torch.no_grad():
            torch.onnx.export(
                wrapped_model, (feat, rois),
                onnx_file,
                export_params=True,
                keep_initializers_as_inputs=True,
                input_names=['input', 'rois'],
                opset_version=11)
        onnx_model = onnx.load(onnx_file)

        # Register the compiled mmcv custom-op library if it exists, so
        # onnxruntime can execute the exported RoIAlign node.
        session_options = rt.SessionOptions()
        if os.path.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)

        # Only the two runtime tensors should remain as graph inputs once
        # the initializers are excluded.
        graph_inputs = [node.name for node in onnx_model.graph.input]
        initializers = [node.name for node in onnx_model.graph.initializer]
        net_feed_input = list(set(graph_inputs) - set(initializers))
        assert (len(net_feed_input) == 2)

        sess = rt.InferenceSession(onnx_file, session_options)
        onnx_output = sess.run(
            None, {
                'input': feat.detach().numpy(),
                'rois': rois.detach().numpy()
            })
        onnx_output = onnx_output[0]

        os.remove(onnx_file)
        assert np.allclose(pytorch_output, onnx_output, atol=1e-3)