예제 #1
0
    def forward(self, laf: torch.Tensor,
                img: torch.Tensor) -> torch.Tensor:  # type: ignore
        """Estimate each LAF's orientation from its image patch and rotate it.

        A patch is extracted for every LAF, ``self.angle_detector`` predicts an
        angle (in radians) per patch, and the LAF's 2x2 shape part is replaced
        by ``make_upright(laf)``'s 2x2 part multiplied by that rotation. LAF
        centers are copied unchanged.

        Args:
            laf: (torch.Tensor), shape [BxNx2x3]
            img: (torch.Tensor), shape [Bx1xHxW]

        Returns:
            laf_out: (torch.Tensor), shape [BxNx2x3]

        Raises:
            TypeError: if ``img`` is not a ``torch.Tensor``.
            ValueError: if ``img`` is not 4-dimensional, or if ``laf`` and
                ``img`` have different batch sizes.
        """
        raise_error_if_laf_is_not_valid(laf)
        # Check the type before touching ``img.shape``: a non-tensor input may
        # have no ``.shape`` and would otherwise fail with an AttributeError
        # instead of the intended TypeError.
        if not isinstance(img, torch.Tensor):
            raise TypeError("img type is not a torch.Tensor. Got {}".format(
                type(img)))
        if len(img.shape) != 4:
            raise ValueError(
                "Invalid img shape, we expect BxCxHxW. Got: {}".format(
                    img.shape))
        if laf.size(0) != img.size(0):
            # Report sizes in the same order the message names them.
            raise ValueError(
                "Batch size of laf and img should be the same. Got {}, {}".
                format(laf.size(0), img.size(0)))
        B, N = laf.shape[:2]
        # Flatten to one single-channel patch per row: [(B*N) x 1 x PS x PS].
        # NOTE(review): assumes img has C == 1 as documented; with C > 1 the
        # patch count would not match B*N and the .view(B, N) below would fail.
        patches: torch.Tensor = extract_patches_from_pyramid(
            img, laf, self.patch_size).view(-1, 1, self.patch_size,
                                            self.patch_size)
        angles_radians: torch.Tensor = self.angle_detector(patches).view(B, N)
        # The detector's radians are converted to degrees before building the
        # per-LAF 2x2 rotation matrices.
        rotmat: torch.Tensor = angle_to_rotation_matrix(
            rad2deg(angles_radians)).view(B * N, 2, 2)

        # 2x2 shape part of the upright LAF, right-multiplied by the rotation.
        # Presumably make_upright removes any existing rotation, so the
        # previous orientation is discarded rather than accumulated — confirm.
        upright_affine = make_upright(laf).view(B * N, 2, 3)[:, :2, :2]
        # LAF centers (third column) are carried over unchanged.
        centers = laf.view(B * N, 2, 3)[:, :2, 2:]
        laf_out: torch.Tensor = torch.cat(
            [torch.bmm(upright_affine, rotmat), centers],
            dim=2).view(B, N, 2, 3)
        return laf_out
예제 #2
0
    def forward(self, laf: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
        """Estimate each LAF's orientation from its image patch and rotate it.

        A patch is extracted for every LAF and ``self.angle_detector``
        predicts an angle (in radians) per patch. That angle, converted to
        degrees, is added to the LAF's current orientation, so the result
        accumulates on top of any existing rotation.

        Args:
            laf: shape [BxNx2x3]
            img: shape [Bx1xHxW]

        Returns:
            laf_out, shape [BxNx2x3]

        Raises:
            TypeError: if ``img`` is not a ``torch.Tensor``.
            ValueError: if ``img`` is not 4-dimensional, or if ``laf`` and
                ``img`` have different batch sizes.
        """
        raise_error_if_laf_is_not_valid(laf)
        # Check the type before touching ``img.shape``: a non-tensor input may
        # have no ``.shape`` and would otherwise fail with an AttributeError
        # instead of the intended TypeError.
        if not isinstance(img, torch.Tensor):
            raise TypeError("img type is not a torch.Tensor. Got {}".format(
                type(img)))
        if len(img.shape) != 4:
            raise ValueError(
                "Invalid img shape, we expect BxCxHxW. Got: {}".format(
                    img.shape))
        if laf.size(0) != img.size(0):
            # Report sizes in the same order the message names them.
            raise ValueError(
                "Batch size of laf and img should be the same. Got {}, {}".
                format(laf.size(0), img.size(0)))
        B, N = laf.shape[:2]
        # Flatten to one single-channel patch per row: [(B*N) x 1 x PS x PS].
        # NOTE(review): assumes img has C == 1 as documented; with C > 1 the
        # patch count would not match B*N and the .view(B, N) below would fail.
        patches: torch.Tensor = extract_patches_from_pyramid(
            img, laf, self.patch_size).view(-1, 1, self.patch_size,
                                            self.patch_size)
        angles_radians: torch.Tensor = self.angle_detector(patches).view(B, N)
        # Current orientation, reshaped to [BxN] to line up with the new angles.
        prev_angle = get_laf_orientation(laf).view_as(angles_radians)
        # The detected angle is converted to degrees before being added, so
        # get_laf_orientation presumably reports degrees too — confirm.
        laf_out: torch.Tensor = set_laf_orientation(
            laf,
            rad2deg(angles_radians) + prev_angle)
        return laf_out
예제 #3
0
    def forward(self, laf: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
        """Estimate each LAF's orientation from its image patch and rotate it.

        A patch is extracted for every LAF and ``self.angle_detector``
        predicts an angle (in radians) per patch. That angle, converted to
        degrees, is added to the LAF's current orientation.

        Args:
            laf: shape [BxNx2x3]
            img: shape [Bx1xHxW]

        Returns:
            laf_out, shape [BxNx2x3]

        Raises:
            ValueError: if ``laf`` and ``img`` have different batch sizes.
            Whatever ``raise_error_if_laf_is_not_valid`` / ``KORNIA_CHECK_SHAPE``
            raise for an invalid ``laf`` or a non-4D ``img``.
        """
        raise_error_if_laf_is_not_valid(laf)
        KORNIA_CHECK_SHAPE(img, ["B", "C", "H", "W"])
        if laf.size(0) != img.size(0):
            # Report sizes in the same order the message names them.
            raise ValueError(f"Batch size of laf and img should be the same. Got {laf.size(0)}, {img.size(0)}")
        B, N = laf.shape[:2]
        # Flatten to one single-channel patch per row: [(B*N) x 1 x PS x PS].
        # NOTE(review): the shape check accepts any C, but this view assumes
        # C == 1 as documented; with C > 1 the .view(B, N) below would fail.
        patches: torch.Tensor = extract_patches_from_pyramid(img, laf, self.patch_size).view(
            -1, 1, self.patch_size, self.patch_size
        )
        angles_radians: torch.Tensor = self.angle_detector(patches).view(B, N)
        # Current orientation, reshaped to [BxN] to line up with the new angles.
        prev_angle = get_laf_orientation(laf).view_as(angles_radians)
        # New orientation = previous + detected (radians converted to degrees).
        laf_out: torch.Tensor = set_laf_orientation(laf, rad2deg(angles_radians) + prev_angle)
        return laf_out