Exemple #1
0
    def forward(self, laf: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
        """Re-estimate the orientation of every local affine frame (LAF).

        One patch per LAF is extracted from ``img``, ``self.angle_detector``
        predicts an angle for each patch, and that angle is added to the
        orientation the LAF already has.

        Args:
            laf: shape [BxNx2x3]
            img: shape [Bx1xHxW]

        Returns:
            laf_out, shape [BxNx2x3]

        Raises:
            TypeError: if ``img`` is not a ``torch.Tensor``.
            ValueError: if ``img`` is not 4-dimensional, or if ``laf`` and
                ``img`` have different batch sizes.
        """
        raise_error_if_laf_is_not_valid(laf)
        # Validate the type before touching ``img.shape``: building the shape
        # message first would raise AttributeError for non-tensor inputs and
        # hide the intended TypeError.
        if not isinstance(img, torch.Tensor):
            raise TypeError("img type is not a torch.Tensor. Got {}".format(
                type(img)))
        if len(img.shape) != 4:
            raise ValueError(
                "Invalid img shape, we expect BxCxHxW. Got: {}".format(
                    img.shape))
        if laf.size(0) != img.size(0):
            # Values are reported in the same order as the message names them.
            raise ValueError(
                "Batch size of laf and img should be the same. Got {}, {}".
                format(laf.size(0), img.size(0)))
        B, N = laf.shape[:2]
        # Flatten the per-LAF patches into one batch of single-channel
        # patch_size x patch_size images for the angle detector.
        patches: torch.Tensor = extract_patches_from_pyramid(
            img, laf, self.patch_size).view(-1, 1, self.patch_size,
                                            self.patch_size)
        angles_radians: torch.Tensor = self.angle_detector(patches).view(B, N)
        prev_angle = get_laf_orientation(laf).view_as(angles_radians)
        # The detector output is converted from radians to degrees before it
        # is added to the existing orientation.
        # NOTE(review): this presumes get_laf_orientation/set_laf_orientation
        # work in degrees -- confirm against the library documentation.
        laf_out: torch.Tensor = set_laf_orientation(
            laf,
            rad2deg(angles_radians) + prev_angle)
        return laf_out
Exemple #2
0
 def forward(self, x: torch.Tensor, mask=None):
     """Detect keypoints on ``x`` and refine their affine shape.

     The image is converted to a uint8 array, ``self.features.detect`` finds
     keypoints on it, the keypoints are turned into LAFs, and ``self.affnet``
     re-estimates the LAFs. The orientation the LAFs had before the affnet
     step is written back afterwards.

     Args:
         x: image tensor. If its maximum value is below 2.0 it is treated as
             being in the [0, 1] range and scaled by 255 before the uint8
             conversion; otherwise it is converted as-is.
             # NOTE(review): presumably BxCxHxW, since channels are averaged
             # over dim=1 below -- confirm against the caller.
         mask: optional tensor restricting where keypoints are detected;
             converted to a uint8 array and passed to the detector.
             ``None`` (default) means no mask.

     Returns:
         Tuple ``(lafs, resp)``: the refined LAFs and the responses returned
         by ``laf_from_opencv_kpts``.
     """
     self.affnet = self.affnet.to(x.device)
     max_val = x.max()
     if max_val < 2.0:
         img_np = (255 * K.tensor_to_image(x)).astype(np.uint8)
     else:
         img_np = K.tensor_to_image(x).astype(np.uint8)
     if mask is not None:
         # Convert the mask itself; converting ``x`` here would silently
         # replace the caller's mask with the image.
         mask = K.tensor_to_image(mask).astype(np.uint8)
     kpts = self.features.detect(img_np, mask)
     lafs, resp = laf_from_opencv_kpts(kpts,
                                       mrSize=self.mrSize,
                                       with_resp=True,
                                       device=x.device)
     # Remember the detector's orientation so it can be restored after the
     # affnet step.
     ori = KF.get_laf_orientation(lafs)
     # affnet receives the channel-averaged image.
     lafs = self.affnet(lafs, x.mean(dim=1, keepdim=True))
     lafs = KF.set_laf_orientation(lafs, ori)
     return lafs, resp