Пример #1
0
    def forward(self, laf: torch.Tensor,
                img: torch.Tensor) -> torch.Tensor:  # type: ignore
        """Refine the affine shape of local affine frames (LAFs).

        A patch is extracted around the upright version of every input LAF,
        ``self.affine_shape_detector`` predicts an ellipse shape for each
        patch, and the resulting LAFs are rescaled so that each one keeps the
        scale of the corresponding input LAF.

        Args:
            laf: (torch.Tensor) shape [BxNx2x3]
            img: (torch.Tensor) shape [Bx1xHxW]

        Returns:
            laf_out: (torch.Tensor) shape [BxNx2x3]

        Raises:
            TypeError: if ``img`` is not a ``torch.Tensor``.
            ValueError: if ``img`` is not 4-dimensional, or if its batch size
                differs from the batch size of ``laf``.
        """
        raise_error_if_laf_is_not_valid(laf)
        # Check the type before touching ``img.shape`` so that a non-tensor
        # input gets the intended TypeError rather than an AttributeError.
        if not torch.is_tensor(img):
            raise TypeError("img type is not a torch.Tensor. Got {}".format(
                type(img)))
        if len(img.shape) != 4:
            raise ValueError(
                "Invalid img shape, we expect BxCxHxW. Got: {}".format(
                    img.shape))
        if laf.size(0) != img.size(0):
            # Values are given in the order the message names them: laf, img.
            raise ValueError(
                "Batch size of laf and img should be the same. Got {}, {}".
                format(laf.size(0), img.size(0)))
        B, N = laf.shape[:2]
        PS: int = self.patch_size
        # NOTE(review): view(-1, 1, PS, PS) assumes single-channel patches,
        # i.e. img is Bx1xHxW as documented above.
        patches: torch.Tensor = extract_patches_from_pyramid(
            img, make_upright(laf), PS, True).view(-1, 1, PS, PS)
        ellipse_shape: torch.Tensor = self.affine_shape_detector(patches)
        # LAF centers are the last column of each 2x3 matrix. Use reshape
        # (not view) so that a non-contiguous ``laf`` is accepted as well.
        centers = laf.reshape(-1, 2, 3)[..., 2].unsqueeze(1)
        # Center (2 values) followed by the predicted shape; the final view
        # requires 5 values per LAF, so ellipse_shape must supply 3 of them.
        ellipses = torch.cat([centers, ellipse_shape], dim=2).view(B, N, 5)
        scale_orig = get_laf_scale(laf)
        laf_out = ellipse_to_laf(ellipses)
        ellipse_scale = get_laf_scale(laf_out)
        # Restore the original scale of every LAF; only its shape changes.
        laf_out = scale_laf(laf_out, scale_orig / ellipse_scale)
        return laf_out
Пример #2
0
    def forward(self, laf: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
        """Re-estimate the affine shape of local affine frames (LAFs) with ``self.features``.

        A PSxPS patch is extracted around the upright version of every LAF. The network
        predicts three values ``(x0, x1, x2)`` per patch, which form the 2x2 shape matrix
        ``[[1 + x0, 0], [x1, 1 + x2]]``. The original LAF centers are kept, the result is
        made upright, and every LAF is rescaled to the scale of the corresponding input LAF.

        Args:
            laf: shape [BxNx2x3]
            img: shape [Bx1xHxW]

        Returns:
            laf_out shape [BxNx2x3]

        Raises:
            TypeError: if ``img`` is not a ``torch.Tensor``.
            ValueError: if ``img`` is not 4-dimensional, or if its batch size differs from the
                batch size of ``laf``.
        """
        raise_error_if_laf_is_not_valid(laf)
        # Check the type before touching ``img.shape``: a non-tensor would otherwise raise an
        # AttributeError before the intended TypeError could be reached.
        if not torch.is_tensor(img):
            raise TypeError("img type is not a torch.Tensor. Got {}".format(type(img)))
        if len(img.shape) != 4:
            raise ValueError("Invalid img shape, we expect BxCxHxW. Got: {}".format(img.shape))
        if laf.size(0) != img.size(0):
            # Values are given in the order the message names them: laf first, then img.
            raise ValueError(
                "Batch size of laf and img should be the same. Got {}, {}".format(laf.size(0), img.size(0))
            )
        B, N = laf.shape[:2]
        PS: int = self.patch_size
        # NOTE(review): view(-1, 1, PS, PS) assumes single-channel patches, i.e. img is Bx1xHxW.
        patches: torch.Tensor = extract_patches_from_pyramid(img, make_upright(laf), PS, True).view(-1, 1, PS, PS)
        xy = self.features(self._normalize_input(patches)).view(-1, 3)
        # Row-major entries of [[1 + x0, 0], [x1, 1 + x2]]. The zero is written as 0 * x0 so it
        # shares dtype, device and autograd graph with the prediction.
        shape_entries = torch.stack([1.0 + xy[:, 0], 0 * xy[:, 0], xy[:, 1], 1.0 + xy[:, 2]], dim=1)
        new_laf_no_center = shape_entries.reshape(B, N, 2, 2)
        # Reattach the original centers (last column of each input LAF).
        new_laf = torch.cat([new_laf_no_center, laf[:, :, :, 2:3]], dim=3)
        scale_orig = get_laf_scale(laf)
        ellipse_scale = get_laf_scale(new_laf)
        # Restore the original scale of every LAF; only its shape changes.
        laf_out = scale_laf(make_upright(new_laf), scale_orig / ellipse_scale)
        return laf_out