Пример #1
0
    def detect_and_describe(self, image: Image) -> Tuple[Keypoints, np.ndarray]:
        """Run SuperPoint on one image and return its keypoints together with their descriptors.

        A fresh SuperPoint model is built from `self._config` on every call and placed on CUDA when
        `self._use_cuda` is set, otherwise on the CPU.

        Args:
            image: the input image; if it carries a mask, detections are filtered by that mask.

        Returns:
            The retained keypoints (at most `self.max_keypoints` are requested via `get_top_k`).
            The descriptor array, indexed along its first axis in the same order as the keypoints.
        """
        # TODO(ayushbaid): fix inference issue #110
        target_device = torch.device("cuda" if self._use_cuda else "cpu")
        network = SuperPoint(self._config).to(target_device)
        network.eval()

        # Preprocess: grayscale, rescale intensities by 1/255 as float32, then prepend two singleton
        # axes (presumably batch and channel -- confirm against the SuperPoint input contract).
        grayscale = image_utils.rgb_to_gray_cv(image)
        normalized = grayscale.value_array.astype(np.float32) / 255.0
        network_input = torch.from_numpy(np.expand_dims(normalized, axis=(0, 1))).to(target_device)

        # Inference without gradient tracking; release cached CUDA memory afterwards.
        with torch.no_grad():
            raw_outputs = network({"image": network_input})
        torch.cuda.empty_cache()

        # Take the first entry of each output and bring it back to numpy on the host.
        outputs = {
            name: raw_outputs[name][0].detach().cpu().numpy()
            for name in ("keypoints", "scores", "descriptors")
        }
        keypoints = Keypoints(outputs["keypoints"], scales=None, responses=outputs["scores"])
        # Transpose so that the first axis can be indexed with keypoint indices below.
        descriptors = outputs["descriptors"].T

        # Drop detections rejected by the image mask (when one is present), keeping descriptors in sync.
        if image.mask is not None:
            keypoints, inside_mask_idxs = keypoints.filter_by_mask(image.mask)
            descriptors = descriptors[inside_mask_idxs]

        # Keep only the top-k keypoints and the matching descriptor rows.
        keypoints, top_k_idxs = keypoints.get_top_k(self.max_keypoints)
        descriptors = descriptors[top_k_idxs]

        return keypoints, descriptors
Пример #2
0
    def test_filter_by_mask(self) -> None:
        """Test the `filter_by_mask` method on a mask with a central square of ones."""
        # Create a (9, 9) mask with ones in a (5, 5) square in the center of the mask and zeros everywhere else.
        mask = np.zeros((9, 9), dtype=np.uint8)
        mask[2:7, 2:7] = 1

        # Test coordinates near corners of square of ones and along the diagonal.
        coordinates = np.array(
            [
                [1.4, 1.4],
                [1.4, 6.4],
                [6.4, 1.4],
                [6.4, 6.4],
                [5.0, 5.0],
                [0.0, 0.0],
                [8.0, 8.0],
            ]
        )
        input_keypoints = Keypoints(coordinates=coordinates)
        # Only the entries at indices 3 and 4 are expected to survive the filtering.
        expected_keypoints = Keypoints(coordinates=coordinates[[3, 4]])

        # Filter the keypoints with the mask; the returned indices are not checked here.
        filtered_keypoints, _ = input_keypoints.filter_by_mask(mask)

        # Use unittest assertions throughout: a bare `assert` is stripped under `python -O`.
        self.assertEqual(len(filtered_keypoints), 2)
        self.assertEqual(filtered_keypoints, expected_keypoints)