예제 #1
0
    def test_real_sift_preextract(self, device, dtype, data):
        """Match pre-extracted SIFT features and verify the recovered homography.

        Features for ``image0`` and ``image1`` are computed up front and stored
        in ``data_dev`` under ``lafs{i}`` / ``descriptors{i}`` before the dict
        is handed to the matcher (presumably so the matcher reuses them instead
        of re-extracting -- confirm against LocalFeatureMatcher). A RANSAC
        homography fitted on the matched keypoints must have more than 50
        inliers and map ``pts0`` onto ``pts1`` within the tolerances below.
        """
        torch.random.manual_seed(0)
        # This is not a unit test, but it is quite a good integration test.
        # Cast extractor and matcher to ``dtype`` as well as ``device`` (as the
        # other integration tests do) so they agree with ``data_dev``.
        feat = SIFTFeature(2000).to(device, dtype)
        matcher = LocalFeatureMatcher(feat, DescriptorMatcher('snn',
                                                              0.8)).to(device, dtype)
        ransac = RANSAC('homography', 1.0, 2048, 10).to(device, dtype)
        data_dev = utils.dict_to(data, device, dtype)
        pts_src = data_dev['pts0']
        pts_dst = data_dev['pts1']

        # No gradients are needed anywhere here, so skip building an autograd
        # graph for both the pre-extraction and the matching step.
        with torch.no_grad():
            for idx in ('0', '1'):
                lafs, _, descs = feat(data_dev[f"image{idx}"])
                data_dev[f"lafs{idx}"] = lafs
                data_dev[f"descriptors{idx}"] = descs
            out = matcher(data_dev)
        homography, inliers = ransac(out['keypoints0'], out['keypoints1'])
        assert inliers.sum().item() > 50  # we have enough inliers
        # Reprojection error of 5px is OK
        assert_close(transform_points(homography[None], pts_src[None]),
                     pts_dst[None],
                     rtol=5e-2,
                     atol=5)
예제 #2
0
 def test_nomatch(self, device, dtype, data):
     """Matching an image against an all-zero copy must yield no keypoints."""
     detector = GFTTAffNetHardNet(100)
     snn = DescriptorMatcher('snn', 0.8)
     matcher = LocalFeatureMatcher(detector, snn).to(device, dtype)
     data_dev = utils.dict_to(data, device, dtype)
     with torch.no_grad():
         reference = data_dev["image0"]
         # The second image is blank, so no correspondences are expected.
         pair = {"image0": reference, "image1": reference * 0}
         out = matcher(pair)
     num_matches = len(out['keypoints0'])
     assert num_matches == 0
예제 #3
0
    def test_jit(self, device, dtype):
        """Check that the TorchScript-compiled matcher reproduces eager outputs.

        Every entry of the eager output dict must be matched (via
        ``assert_close``) by the entry with the same key from the scripted
        model.
        """
        B, C, H, W = 1, 1, 32, 32
        patches = torch.rand(B, C, H, W, device=device, dtype=dtype)
        # Upscaled 32x32 -> 48x48 copy used as the second image.
        patches_resized = resize(patches, (48, 48))
        inputs = {"image0": patches, "image1": patches_resized}
        # NOTE(review): SIFTDescriptor is a patch descriptor; LocalFeatureMatcher
        # may expect a detector+descriptor module (e.g. SIFTFeature) -- confirm.
        # The model is cast to the inputs' dtype (not only device) so eager and
        # scripted runs operate at the same precision as the data.
        model = LocalFeatureMatcher(SIFTDescriptor(32),
                                    DescriptorMatcher('snn', 0.8)).to(
                                        device, dtype).eval()
        model_jit = torch.jit.script(model)

        # Pure inference comparison: no autograd graph is needed.
        with torch.no_grad():
            out = model(inputs)
            out_jit = model_jit(inputs)
        for k, v in out.items():
            assert_close(v, out_jit[k])
예제 #4
0
    def test_gradcheck(self, device):
        """Numerically check gradients of ``keypoints0`` w.r.t. both images."""
        # gradcheck's finite differences are only reliable in double precision,
        # so the matcher is cast to float64 to match the gradcheck inputs
        # (presumably float64 from ``tensor_to_gradcheck_var`` -- confirm).
        matcher = LocalFeatureMatcher(SIFTFeature(5),
                                      DescriptorMatcher('nn', 1.0)).to(
                                          device, torch.float64)
        patches = torch.rand(1, 1, 32, 32, device=device)
        # Upscaled 32x32 -> 48x48 copy used as the second image.
        patches_resized = resize(patches, (48, 48))
        patches = utils.tensor_to_gradcheck_var(patches)  # to var
        patches_resized = utils.tensor_to_gradcheck_var(patches_resized)  # to var

        def proxy_forward(x, y):
            # Only the matched keypoints of the first image are differentiated.
            return matcher({"image0": x, "image1": y})["keypoints0"]

        assert gradcheck(proxy_forward, (patches, patches_resized),
                         eps=1e-4,
                         atol=1e-4,
                         raise_exception=True)
예제 #5
0
 def test_real_keynet(self, device, dtype, data):
     """End-to-end check that KeyNet+HardNet matches recover the known homography.

     RANSAC is fitted on the matched keypoints. It must find more than 50
     inliers, and the estimated homography must map ``pts0`` onto ``pts1``
     within the tolerances below.
     """
     torch.random.manual_seed(0)
     # Integration test rather than a true unit test.
     local_feature = KeyNetHardNet(500)
     snn = DescriptorMatcher('snn', 0.9)
     matcher = LocalFeatureMatcher(local_feature, snn).to(device, dtype)
     estimator = RANSAC('homography', 1.0, 2048, 10).to(device, dtype)
     data_dev = utils.dict_to(data, device, dtype)
     src_points = data_dev['pts0']
     dst_points = data_dev['pts1']
     with torch.no_grad():
         out = matcher(data_dev)
     kpts0, kpts1 = out['keypoints0'], out['keypoints1']
     homography, inliers = estimator(kpts0, kpts1)
     num_inliers = inliers.sum().item()
     assert num_inliers > 50  # enough inliers to trust the estimate
     projected = transform_points(homography[None], src_points[None])
     # A reprojection error of up to 5px is acceptable.
     assert_close(projected, dst_points[None], rtol=5e-2, atol=5)
예제 #6
0
 def test_smoke(self, device):
     """Constructing the matcher and moving it to ``device`` must succeed."""
     matcher = LocalFeatureMatcher(SIFTFeature(5),
                                   DescriptorMatcher('snn', 0.8)).to(device)
     # ``is not None`` was vacuous, since a constructor never returns None.
     # Check instead that ``.to`` handed back the matcher itself.
     assert isinstance(matcher, LocalFeatureMatcher)