Ejemplo n.º 1
0
    def test_real_sift_preextract(self, device, dtype, data):
        """Integration check: SIFT features, matching, then a RANSAC homography.

        Features for both images are extracted up front and written into the
        input dict under ``lafs*`` / ``descriptors*`` before the matcher runs
        (presumably so the matcher reuses them, as the test name suggests --
        confirm against ``LocalFeatureMatcher``). The estimated homography
        must map ``pts0`` onto ``pts1`` within a loose tolerance.
        """
        # Seed the global RNG so repeated runs see the same random state.
        torch.random.manual_seed(0)

        extractor = SIFTFeature(2000)
        matcher = LocalFeatureMatcher(extractor, DescriptorMatcher('snn', 0.8))
        matcher = matcher.to(device)
        estimator = RANSAC('homography', 1.0, 2048, 10).to(device, dtype)

        batch = utils.dict_to(data, device, dtype)
        src_pts, dst_pts = batch['pts0'], batch['pts1']

        # Extract features for image0 first, then image1, storing each result
        # next to its image in the same dict that is handed to the matcher.
        for image_key, lafs_key, descs_key in (
            ("image0", "lafs0", "descriptors0"),
            ("image1", "lafs1", "descriptors1"),
        ):
            lafs, _, descs = extractor(batch[image_key])
            batch[lafs_key] = lafs
            batch[descs_key] = descs

        with torch.no_grad():
            matches = matcher(batch)
        homography, inlier_mask = estimator(matches['keypoints0'],
                                            matches['keypoints1'])

        num_inliers = inlier_mask.sum().item()
        assert num_inliers > 50  # enough inliers were found

        # A reprojection error of up to 5px is acceptable.
        projected = transform_points(homography[None], src_pts[None])
        assert_close(projected, dst_pts[None], rtol=5e-2, atol=5)
Ejemplo n.º 2
0
 def test_gradcheck(self, device):
     """Numerically check gradients of ``SIFTFeature`` w.r.t. its input.

     A random 1x1x32x32 image is created on ``device`` and passed through
     ``gradcheck`` with ``eps=1e-4`` and ``atol=1e-4``.
     """
     B, C, H, W = 1, 1, 32, 32
     img = torch.rand(B, C, H, W, device=device)
     # One ``.to(device)`` is enough; repeating the same call is a no-op.
     local_feature = SIFTFeature(2, True).to(device)
     # NOTE(review): the helper presumably enables ``requires_grad`` and casts
     # to double precision for gradcheck -- confirm in ``utils``.
     img = utils.tensor_to_gradcheck_var(img)  # to var
     assert gradcheck(local_feature,
                      img,
                      eps=1e-4,
                      atol=1e-4,
                      raise_exception=True)
Ejemplo n.º 3
0
    def test_gradcheck(self, device):
        """Run ``gradcheck`` through the matcher for a pair of input images.

        The second image is the first one passed through ``resize`` with a
        target size of ``(48, 48)``. The value being checked is the
        ``"keypoints0"`` entry of the matcher output.
        """
        detector = SIFTFeature(5)
        nn_matcher = DescriptorMatcher('nn', 1.0)
        matcher = LocalFeatureMatcher(detector, nn_matcher).to(device)

        base = torch.rand(1, 1, 32, 32, device=device)
        # Resize before wrapping, so the resized copy derives from the raw
        # tensor rather than from the gradcheck variable.
        resized = resize(base, (48, 48))
        inputs = (
            utils.tensor_to_gradcheck_var(base),  # to var
            utils.tensor_to_gradcheck_var(resized),  # to var
        )

        def keypoints0_of(image0, image1):
            # gradcheck passes positional tensors; the matcher takes a dict.
            result = matcher({"image0": image0, "image1": image1})
            return result["keypoints0"]

        assert gradcheck(keypoints0_of,
                         inputs,
                         eps=1e-4,
                         atol=1e-4,
                         raise_exception=True)
Ejemplo n.º 4
0
 def test_smoke(self, device):
     """Smoke test: build a SIFT-based matcher and move it to ``device``."""
     detector = SIFTFeature(5)
     descriptor_matcher = DescriptorMatcher('snn', 0.8)
     matcher = LocalFeatureMatcher(detector, descriptor_matcher)
     matcher = matcher.to(device)
     assert matcher is not None
Ejemplo n.º 5
0
 def test_smoke(self, device, dtype):
     """Smoke test: ``SIFTFeature`` can be built with default arguments."""
     # ``device`` and ``dtype`` are accepted but not used by this check.
     assert SIFTFeature() is not None