Пример #1
0
def _test_roialign_allclose(device, dtype):
    """Compare ``roi_align`` forward/backward results with reference values.

    For every pair taken from the module-level ``inputs`` and ``outputs``
    sequences, the op is run on ``device`` with ``dtype``, a gradient of
    ones is back-propagated, and both the output and the input gradient are
    compared with the stored references (absolute tolerance ``1e-3``).

    Args:
        device: Device the tensors are created on. The test is skipped when
            it equals ``'cuda'`` and no GPU is available.
        dtype: Torch dtype used for the feature map and the RoIs.
    """
    if device == 'cuda' and not torch.cuda.is_available():
        pytest.skip('test requires GPU')
    try:
        from mmcv.ops import roi_align
    except ModuleNotFoundError:
        pytest.skip('test requires compilation')

    # RoIAlign configuration: (pool_h, pool_w) output, scale and sampling.
    output_size = (2, 2)
    spatial_scale = 1.0
    sampling_ratio = 2

    def to_float_numpy(tensor):
        # Results are cast to float32 before comparison, whatever ``dtype``
        # the op was run with.
        return tensor.data.type(torch.float).cpu().numpy()

    for case, reference in zip(inputs, outputs):
        feature_array = np.array(case[0])
        roi_array = np.array(case[1])
        expected_output = np.array(reference[0])
        expected_grad = np.array(reference[1])

        features = torch.tensor(
            feature_array, dtype=dtype, device=device, requires_grad=True)
        rois = torch.tensor(roi_array, dtype=dtype, device=device)

        pooled = roi_align(features, rois, output_size, spatial_scale,
                           sampling_ratio, 'avg', True)
        pooled.backward(torch.ones_like(pooled))

        assert np.allclose(to_float_numpy(pooled), expected_output, atol=1e-3)
        assert np.allclose(
            to_float_numpy(features.grad), expected_grad, atol=1e-3)
Пример #2
0
def test_roialign():
    """Check that ``roi_align`` exported to ONNX matches the PyTorch result.

    Each sample is run through ``roi_align`` in PyTorch, exported to
    ``onnx_file`` (opset 11) through ``WrapFunction``, executed with
    onnxruntime, and the two outputs are compared with an absolute
    tolerance of ``1e-3``.

    The exported file is removed in a ``finally`` clause so that a failing
    export, load or inference step does not leave it behind.
    """
    from mmcv.ops import roi_align

    # roi align config
    pool_h = 2
    pool_w = 2
    spatial_scale = 1.0
    sampling_ratio = 2
    # NOTE(review): mmcv documents output_size as (h, w); the original code
    # passed (pool_w, pool_h). Both are 2 here, so results are unchanged.
    output_size = (pool_h, pool_w)

    # Each case is (feature map as nested lists, RoIs with five values each).
    inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]),
              ([[[[1., 2.], [3., 4.]], [[4., 3.],
                                        [2., 1.]]]], [[0., 0., 0., 1., 1.]]),
              ([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
                  [11., 12., 15., 16.]]]], [[0., 0., 0., 3., 3.]])]

    def wrapped_function(torch_input, torch_rois):
        return roi_align(torch_input, torch_rois, output_size,
                         spatial_scale, sampling_ratio, 'avg', True)

    # The wrapper does not depend on the sample, so build it once.
    wrapped_model = WrapFunction(wrapped_function)

    for case in inputs:
        np_input = np.array(case[0], dtype=np.float32)
        np_rois = np.array(case[1], dtype=np.float32)
        input_tensor = torch.from_numpy(np_input)
        rois = torch.from_numpy(np_rois)

        # compute pytorch_output
        with torch.no_grad():
            pytorch_output = wrapped_function(input_tensor, rois)

        try:
            # export and load onnx model
            with torch.no_grad():
                torch.onnx.export(wrapped_model, (input_tensor, rois),
                                  onnx_file,
                                  export_params=True,
                                  keep_initializers_as_inputs=True,
                                  input_names=['input', 'rois'],
                                  opset_version=11)
            onnx_model = onnx.load(onnx_file)

            # Graph inputs that are not initializers are the tensors that
            # must be fed at run time; exactly two are expected.
            input_all = [node.name for node in onnx_model.graph.input]
            input_initializer = [
                node.name for node in onnx_model.graph.initializer
            ]
            net_feed_input = list(set(input_all) - set(input_initializer))
            assert len(net_feed_input) == 2

            # compute onnx_output
            sess = rt.InferenceSession(onnx_file)
            onnx_output = sess.run(None, {
                'input': input_tensor.detach().numpy(),
                'rois': rois.detach().numpy()
            })[0]
        finally:
            # Clean up even when export/inference fails; the file may not
            # exist if the export itself raised.
            if os.path.exists(onnx_file):
                os.remove(onnx_file)

        # allclose
        assert np.allclose(pytorch_output, onnx_output, atol=1e-3)
Пример #3
0
def crop_image_patch_v2(pos_proposals, pos_assigned_gt_inds, gt_masks):
    """Crop and resize ground-truth masks to the proposals with ``roi_align``.

    One mask is selected per proposal via ``pos_assigned_gt_inds``; proposal
    ``i`` is then given index ``i`` in the first RoI column, and the result
    of ``roi_align`` (output size ``_pair(28)`` reversed) is returned with
    dimension 1 squeezed.

    Args:
        pos_proposals: 2-D tensor of proposals; its device and dtype are
            used for the RoIs (the original comment states the RoIs are
            Nx5, i.e. this is presumably Nx4 -- TODO confirm).
        pos_assigned_gt_inds: Index tensor passed to ``index_select`` along
            dimension 0 of the masks.
        gt_masks: NumPy array of masks, converted with ``torch.from_numpy``.

    Returns:
        The ``roi_align`` output with dimension 1 squeezed.
    """
    import torch
    from torch.nn.modules.utils import _pair

    device = pos_proposals.device
    mask_size = _pair(28)

    # Build the RoIs: a running index column followed by the proposals.
    num_pos = pos_proposals.size(0)
    roi_inds = torch.arange(num_pos, device=device)
    roi_inds = roi_inds.to(dtype=pos_proposals.dtype)[:, None]
    rois = torch.cat([roi_inds, pos_proposals], dim=1)  # Nx5
    rois = rois.to(device=device)

    # Pick the mask assigned to each proposal and match the RoI dtype.
    selected_masks = torch.from_numpy(gt_masks).to(device)
    selected_masks = selected_masks.index_select(0, pos_assigned_gt_inds)
    selected_masks = selected_masks.to(dtype=rois.dtype)

    # RoIAlign apparently speeds up training (~0.1s/iter).
    # NOTE(review): the meaning of the trailing positional ``True`` depends
    # on the signature of the ``roi_align`` in scope -- confirm it is the
    # intended argument. Also confirm the masks already carry the channel
    # dimension that ``squeeze(1)`` removes.
    pooled = roi_align(selected_masks, rois, mask_size[::-1], 1.0, 0, True)
    return pooled.squeeze(1)
Пример #4
0
def bbox_feat_extractor(feature_maps, boxes, pool_size):
    """Pool one feature per box from a single image with ``roi_align``.

    Only batch size 1 is supported: every RoI is given batch index 0.

    Args:
        feature_maps: Feature tensor. A 3-D tensor gets a leading batch
            dimension added (the original docstring gives the size as
            1*C*h*w).
        boxes: 2-D tensor of boxes. Columns 0-3 are read and swapped
            pairwise; the original docstring gives the order as
            (y1, x1, y2, x2) "with normalization", which would make the
            RoIs (x1, y1, x2, y2) -- TODO confirm against callers.
        pool_size: Forwarded unchanged as the third ``roi_align`` argument.

    Returns:
        The value returned by ``roi_align``.
    """
    # Swap each coordinate pair: columns (0, 1, 2, 3) -> (1, 0, 3, 2).
    boxes = boxes[:, [1, 0, 3, 2]]

    # Batch-index column for the RoIs (all zeros: batch size 1). Allocating
    # it from ``boxes`` keeps device and dtype in sync with the boxes;
    # ``torch.zeros(...).cuda()`` would land on the current CUDA device and
    # always be float32.
    box_ind = boxes.new_zeros((boxes.size(0), 1))

    # roi_align needs a batch dimension
    if feature_maps.dim() == 3:
        feature_maps = feature_maps.unsqueeze(0)

    # make crops:
    rois = torch.cat([box_ind, boxes], dim=1)
    return roi_align(feature_maps, rois, pool_size)
Пример #5
0
 def warpped_function(torch_input, torch_rois):
     """Run ``roi_align`` with the configuration of the enclosing scope."""
     # NOTE(review): the size is given as (pool_w, pool_h); confirm the
     # order expected by the ``roi_align`` in scope.
     output_size = (pool_w, pool_h)
     return roi_align(torch_input, torch_rois, output_size, spatial_scale,
                      sampling_ratio, 'avg', True)
Пример #6
0
def test_roialign():
    """Check ``roi_align`` ONNX export against PyTorch using onnxruntime.

    The test is skipped under parrots and when the mmcv ops cannot be
    imported. For each sample the op is run in PyTorch, exported to
    ``onnx_file`` (opset 11) through ``WrapFunction``, executed with
    onnxruntime (registering the mmcv custom-op library when its path
    exists), and the two outputs are compared with an absolute tolerance
    of ``1e-3``.

    The exported file is removed in a ``finally`` clause so that a failing
    export, load or inference step does not leave it behind.
    """
    if torch.__version__ == 'parrots':
        pytest.skip('onnx is not supported in parrots directly')
    try:
        from mmcv.ops import roi_align
        from mmcv.ops import get_onnxruntime_op_path
    except ImportError:  # ModuleNotFoundError is a subclass of ImportError
        pytest.skip('roi_align op is not successfully compiled')

    ort_custom_op_path = get_onnxruntime_op_path()
    # roi align config
    pool_h = 2
    pool_w = 2
    spatial_scale = 1.0
    sampling_ratio = 2
    # NOTE(review): mmcv documents output_size as (h, w); the original code
    # passed (pool_w, pool_h). Both are 2 here, so results are unchanged.
    output_size = (pool_h, pool_w)

    # Each case is (feature map as nested lists, RoIs with five values each).
    inputs = [([[[[1., 2.], [3., 4.]]]], [[0., 0., 0., 1., 1.]]),
              ([[[[1., 2.], [3., 4.]], [[4., 3.],
                                        [2., 1.]]]], [[0., 0., 0., 1., 1.]]),
              ([[[[1., 2., 5., 6.], [3., 4., 7., 8.], [9., 10., 13., 14.],
                  [11., 12., 15., 16.]]]], [[0., 0., 0., 3., 3.]])]

    def wrapped_function(torch_input, torch_rois):
        return roi_align(torch_input, torch_rois, output_size,
                         spatial_scale, sampling_ratio, 'avg', True)

    # The wrapper does not depend on the sample, so build it once.
    wrapped_model = WrapFunction(wrapped_function)

    for case in inputs:
        np_input = np.array(case[0], dtype=np.float32)
        np_rois = np.array(case[1], dtype=np.float32)
        input_tensor = torch.from_numpy(np_input)
        rois = torch.from_numpy(np_rois)

        # compute pytorch_output
        with torch.no_grad():
            pytorch_output = wrapped_function(input_tensor, rois)

        try:
            # export and load onnx model
            with torch.no_grad():
                torch.onnx.export(wrapped_model, (input_tensor, rois),
                                  onnx_file,
                                  export_params=True,
                                  keep_initializers_as_inputs=True,
                                  input_names=['input', 'rois'],
                                  opset_version=11)

            onnx_model = onnx.load(onnx_file)
            session_options = rt.SessionOptions()
            # Register the compiled custom ops only when the library exists.
            if os.path.exists(ort_custom_op_path):
                session_options.register_custom_ops_library(
                    ort_custom_op_path)

            # Graph inputs that are not initializers are the tensors that
            # must be fed at run time; exactly two are expected.
            input_all = [node.name for node in onnx_model.graph.input]
            input_initializer = [
                node.name for node in onnx_model.graph.initializer
            ]
            net_feed_input = list(set(input_all) - set(input_initializer))
            assert len(net_feed_input) == 2

            # compute onnx_output
            sess = rt.InferenceSession(onnx_file, session_options)
            onnx_output = sess.run(None, {
                'input': input_tensor.detach().numpy(),
                'rois': rois.detach().numpy()
            })[0]
        finally:
            # Clean up even when export/inference fails; the file may not
            # exist if the export itself raised.
            if os.path.exists(onnx_file):
                os.remove(onnx_file)

        # allclose
        assert np.allclose(pytorch_output, onnx_output, atol=1e-3)